2025-11-25T22:34:18.624435

Efficient Autoregressive Inference for Transformer Probabilistic Models

Hassan, Loka, Li et al.
Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtain joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
academic

Efficient Autoregressive Inference for Transformer Probabilistic Models

基本信息

  • 论文ID: 2510.09477
  • 标题: Efficient Autoregressive Inference for Transformer Probabilistic Models
  • 作者: Conor Hassan, Nasrulloh Loka, Cen-You Li, Daolang Huang, Paul E. Chang, Yang Yang, Francesco Silvestrin, Samuel Kaski, Luigi Acerbi
  • 分类: stat.ML cs.LG
  • 发表时间: 2025年10月10日 (arXiv预印本)
  • 论文链接: https://arxiv.org/abs/2510.09477

摘要

基于Transformer的摊销概率推理模型(如神经过程、先验拟合网络和表格基础模型)在单次边际预测方面表现出色。然而,从信号插值到多列表格预测等许多实际应用需要捕获预测间依赖关系的连贯联合分布。纯自回归架构能高效生成此类分布,但牺牲了使这些模型在元学习中强大的灵活集合条件化能力。相反,从基于集合的模型获得联合分布的标准方法需要在每个自回归步骤中对整个增强条件集进行昂贵的重新编码。本文引入了因果自回归缓冲区,保留了两种范式的优势。该方法将上下文编码与条件集更新解耦,模型处理上下文一次并缓存,动态缓冲区捕获目标依赖关系。在合成函数、EEG信号、认知模型和表格数据上,该方法在匹配强基线预测准确性的同时,联合采样速度提升高达20倍。

研究背景与动机

核心问题

现有的基于Transformer的概率模型面临一个根本性的效率瓶颈:当需要生成联合分布时,必须在每个自回归步骤中重新编码整个条件集。具体而言:

  1. 集合条件化模型的局限:神经过程(NPs)、先验拟合网络(PFNs)等模型擅长边际预测,但在自回归部署时需要重复重新编码上下文,导致O(K(N+K)²)的计算复杂度
  2. 纯自回归模型的不足:虽然计算高效,但缺乏灵活的集合条件化能力,限制了在元学习任务中的应用

重要性

联合分布预测在多个关键应用中至关重要:

  • 信号插值中的时间依赖关系
  • 多列表格预测中的特征相关性
  • 行为数据建模中的序列依赖
  • 贝叶斯模型选择中的联合似然评估

现有方法局限性

  1. TNP-D自回归部署:每步需要重新编码增长的条件集
  2. TNP-A:训练和推理都需要处理重复的目标集,计算开销巨大
  3. TNP-ND:仅限于多元高斯分布,表达能力受限

核心贡献

  1. 提出因果自回归缓冲区机制:将集合条件化的上下文编码与序列预测解耦,实现高效的联合采样和似然评估
  2. 设计统一训练策略:使用掩码注意力和缓冲区大小课程学习,使单一模型能以最小额外成本学习两种操作模式
  3. 广泛适用性验证:在TNPs/PFNs和表格基础模型上实现高达20倍的联合采样加速,同时保持可比较的预测准确性
  4. 理论复杂度优化:将计算复杂度从O(K(N+K)²)降低到O(N²+NK+K²)

方法详解

任务定义

给定上下文集C = {(xₙ, yₙ)}ᴺₙ₌₁和目标集T = {(xₘ, yₘ)}ᴹₘ₌₁,目标是学习预测分布p_θ(y₁:ₘ|x₁:ₘ; C),其中θ为模型参数。

模型架构

核心组件

  1. 上下文编码器rC:处理上下文对,使用双向多头自注意力,缓存每层的键值对
  2. 缓冲区编码器rB:对缓冲区前缀使用严格因果多头自注意力
  3. 目标解码器rtgt:通过交叉注意力查询缓存的上下文和可见缓冲区前缀

预测分布参数化

p_θ(y*₁:K|x*₁:K; C) = ∏ᴷₖ₌₁ p_θ(y*ₖ|rtgt(x*ₖ, [rC(C), b₁:ₖ₋₁]))

其中bₖ = rB((xₖ, yₖ), rC(C), b₁:ₖ₋₁)

注意力掩码设计

实现四个关键要求:

  • (R1) 上下文不可变:一次编码并缓存为只读
  • (R2) 缓冲区严格因果:token j只能关注<j的位置
  • (R3) 信息单向流出上下文:无边写入C
  • (R4) 目标关注缓存上下文和可见缓冲区前缀

技术创新点

1. 解耦设计

  • 静态上下文缓存:一次编码,多次重用
  • 动态缓冲区:增量更新,捕获目标间依赖

2. 训练课程

  • 50%目标仅关注上下文
  • 50%目标关注上下文+随机长度缓冲区前缀
  • 确保模型在不同缓冲区状态下都能良好工作

3. 高效推理模式

  • 自回归采样:预填充上下文,序列解码目标
  • 联合似然评估:单次前向传播计算所有条件概率
  • 批量采样:共享上下文缓存,独立缓冲区状态

实验设置

数据集

  1. 合成函数
    • 高斯过程(GP):RBF、Matérn-3/2、Matérn-5/2核
    • 锯齿函数:非高斯,不连续导数
  2. EEG数据:11,520个试验,122个受试者,7个相关通道,256个时间点
  3. 多感官因果推理模型:音视觉定位实验数据,15名参与者
  4. 表格数据:UCI数据集(电力消耗、燃气轮机排放、自行车共享)

评价指标

  • 平均对数似然:评估预测质量
  • 壁钟时间:采样、似然评估、训练步骤的实际运行时间
  • 对数边际似然RMSE:模型选择任务的准确性

对比方法

  • TNP-D-Ind:独立预测,快速但无依赖建模
  • TNP-D-AR:自回归部署,表达力强但需重新编码
  • TNP-ND:多元高斯联合分布,表达力有限
  • TNP-A:完全自回归建模,训练和采样都很慢

实现细节

  • 优化器:Adam,学习率1×10⁻⁴
  • 架构:6层Transformer,4个注意力头,维度128
  • 预测头:20组件高斯混合模型
  • 缓冲区大小:K=16(主要实验)

实验结果

主要结果

计算效率

  • 自回归采样:比TNP-A和TNP-D-AR快3-20倍
  • 似然评估:与TNP-A相当,比TNP-D-AR快K倍
  • 训练速度:比TNP-A快4-12倍,与最快基线相当

预测准确性

数据集TNP-D-ARTNP-A本方法(K=16)本方法(K=1)
GP2.570.802.512.56
Sawtooth1.05-0.431.001.09
EEG-Int0.510.460.520.54
EEG-For1.07-0.040.851.21

消融实验

  • 缓冲区大小影响:K=1时等效于标准自回归,K=16时略有性能下降但速度大幅提升
  • 自定义Triton核:在大批量时提供显著加速
  • 注意力模式:即使禁用FlashAttention,TNP-A仍比其他方法慢数个数量级

案例分析

在多感官因果推理任务中:

  • 模型选择:LML RMSE为3.56,接近TNP-D-AR的3.47
  • 数据预测:平均对数似然为-2.76,与所有强基线相当
  • 与真实值相关性:R²=1.00(LML),R²=0.92(ΔLML)

相关工作

神经过程和先验拟合网络

本方法作为模块化组件,可集成到现有NP/PFN架构中。与专注于上下文集可扩展性的先前工作互补,本文针对自回归联合采样效率。

Transformer概率模型

构建在将贝叶斯推理框架为上下文学习任务的趋势之上,利用基于Transformer的NP和PFN变体。

表格基础模型

与TabPFN和TabICL等模型自然集成,为高效联合预测提供补充模块。

自回归联合密度估计

与TNP-A相关但有关键区别:TNP-A在训练和推理中都使用目标重复,而本方法仅在似然评估时需要。

结论与讨论

主要结论

  1. 效率突破:成功将自回归Transformer的效率引入NP/PFN框架
  2. 性能保持:在大幅提升速度的同时保持预测准确性
  3. 广泛适用:在多个领域和任务中验证了方法的有效性

局限性

  1. 缓冲区长度扩展:K增大时仍有O(K²)项,且当前使用固定位置嵌入
  2. 长缓冲区质量漂移:相比每步重新编码的精确自回归可能有质量下降
  3. 内存占用:需要维护上下文缓存和缓冲区状态

未来方向

  1. 位置编码改进:使用RoPE或ALiBi支持更长序列
  2. 推测解码:借鉴draft-verify过程的自适应推理策略
  3. 参数高效微调:使用适配器或LoRA为预训练模型添加缓冲区功能

深度评价

优点

  1. 创新性强:巧妙解决了集合条件化与自回归效率的权衡问题
  2. 理论扎实:提供了清晰的复杂度分析和数学推导
  3. 实验全面:涵盖合成数据、真实数据、多个应用领域
  4. 工程优化:包含自定义CUDA核等底层优化
  5. 可复现性:提供详细的实现细节和将开源代码

不足

  1. 适用范围:主要适用于中等长度的目标序列,超长序列仍面临挑战
  2. 理论分析:缺乏缓冲区近似误差的理论界限分析
  3. 对比实验:未与最新的高效注意力机制(如线性注意力)对比

影响力

  1. 学术价值:为概率模型的高效推理提供了新思路
  2. 实用价值:显著降低了联合预测的计算成本,使实际应用成为可能
  3. 可扩展性:方法具有良好的通用性,可应用于多种Transformer变体

适用场景

  • 需要频繁联合采样的应用(如不确定性量化)
  • 大规模上下文的序列预测任务
  • 实时推理要求较高的场景
  • 多模态数据的联合建模

参考文献

主要参考文献包括:

  • Garnelo et al. (2018): Neural Processes原始论文
  • Nguyen & Grover (2022): Transformer Neural Processes
  • Müller et al. (2022): Prior-Fitted Networks
  • Bruinsma et al. (2023): Autoregressive Conditional Neural Processes
  • Jingang et al. (2025): TabICL表格基础模型

总体评价:这是一篇高质量的研究论文,在理论创新、实验验证和工程实现方面都表现出色。该方法成功解决了概率模型中的一个重要效率瓶颈,具有广泛的应用前景和学术价值。