2025-11-25T22:34:18.624435

Efficient Autoregressive Inference for Transformer Probabilistic Models

Hassan, Loka, Li et al.
Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtain joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
academic

Transformer確率モデルの効率的な自回帰推論

基本情報

  • 論文ID: 2510.09477
  • タイトル: Efficient Autoregressive Inference for Transformer Probabilistic Models
  • 著者: Conor Hassan, Nasrulloh Loka, Cen-You Li, Daolang Huang, Paul E. Chang, Yang Yang, Francesco Silvestrin, Samuel Kaski, Luigi Acerbi
  • 分類: stat.ML cs.LG
  • 発表日: 2025年10月10日 (arXivプレプリント)
  • 論文リンク: https://arxiv.org/abs/2510.09477

要約

Transformerベースの償却確率推論モデル(ニューラルプロセス、先験適合ネットワーク、表形式基礎モデルなど)は、単一の周辺予測において優れた性能を示しています。しかし、信号補間から多列表形式予測まで、多くの実用的なアプリケーションでは、予測間の依存関係を捉える一貫した同時分布が必要です。純粋な自回帰アーキテクチャはこのような分布を効率的に生成できますが、これらのモデルをメタラーニングで強力にする柔軟な集合条件付け能力を犠牲にしています。対照的に、集合ベースのモデルから同時分布を得る標準的な方法は、各自回帰ステップで拡張された条件付きセット全体の高価な再エンコーディングが必要です。本論文では因果自回帰バッファを導入し、両方のパラダイムの利点を保持します。この方法は文脈エンコーディングと条件付きセット更新を分離し、モデルが文脈を一度処理してキャッシュし、動的バッファが目標依存関係を捉えます。合成関数、EEG信号、認知モデル、表形式データ上で、この方法は強力なベースライン予測精度と一致させながら、同時サンプリング速度を最大20倍向上させます。

研究背景と動機

核心的な問題

既存のTransformerベースの確率モデルは、同時分布を生成する必要がある場合に根本的な効率ボトルネックに直面しています。具体的には:

  1. 集合条件付けモデルの制限:ニューラルプロセス(NP)、先験適合ネットワーク(PFN)などのモデルは周辺予測に優れていますが、自回帰配置では文脈の繰り返し再エンコーディングが必要であり、O(K(N+K)²)の計算複雑度をもたらします
  2. 純粋な自回帰モデルの不足:計算効率は高いですが、柔軟な集合条件付け能力が不足しており、メタラーニングタスクでの応用が制限されています

重要性

同時分布予測は複数の重要なアプリケーションで不可欠です:

  • 信号補間における時間依存関係
  • 多列表形式予測における特徴相関
  • 行動データモデリングにおける系列依存性
  • ベイズモデル選択における同時尤度評価

既存手法の制限

  1. TNP-D自回帰配置:各ステップで拡張される条件付きセットの再エンコーディングが必要
  2. TNP-A:訓練と推論の両方で繰り返される目標セットの処理が必要で、計算オーバーヘッドが大きい
  3. TNP-ND:多変量ガウス分布に限定され、表現能力が制限されている

核心的な貢献

  1. 因果自回帰バッファメカニズムの提案:集合条件付けの文脈エンコーディングと系列予測を分離し、効率的な同時サンプリングと尤度評価を実現
  2. 統一訓練戦略の設計:マスク付き注意とバッファサイズカリキュラム学習を使用し、単一モデルが最小限の追加コストで両方の動作モードを学習できるようにする
  3. 広範な適用性の検証:TNP/PFNおよび表形式基礎モデル上で最大20倍の同時サンプリング加速を実現しながら、比較可能な予測精度を維持
  4. 理論的複雑度の最適化:計算複雑度をO(K(N+K)²)からO(N²+NK+K²)に削減

方法の詳細

タスク定義

文脈セットC = {(xₙ, yₙ)}ᴺₙ₌₁と目標セットT = {(xₘ, yₘ)}ᴹₘ₌₁が与えられた場合、目標は予測分布p_θ(y₁:ₘ|x₁:ₘ; C)を学習することです。ここでθはモデルパラメータです。

モデルアーキテクチャ

核心コンポーネント

  1. 文脈エンコーダrC:文脈ペアを処理し、双方向マルチヘッド自己注意を使用し、各層のキー値ペアをキャッシュ
  2. バッファエンコーダrB:バッファプレフィックスに厳密な因果マルチヘッド自己注意を使用
  3. 目標デコーダrtgt:キャッシュされた文脈と可視バッファプレフィックスをクロスアテンションで照会

予測分布のパラメータ化

p_θ(y*₁:K|x*₁:K; C) = ∏ᴷₖ₌₁ p_θ(y*ₖ|rtgt(x*ₖ, [rC(C), b₁:ₖ₋₁]))

ここでbₖ = rB((xₖ, yₖ), rC(C), b₁:ₖ₋₁)

注意マスク設計

4つの重要な要件を実装:

  • (R1) 文脈の不変性:一度エンコードしてキャッシュとして読み取り専用
  • (R2) バッファの厳密な因果性:トークンjは<jの位置のみに注意可能
  • (R3) 文脈からの情報の一方向フロー:Cへの逆方向書き込みなし
  • (R4) 目標はキャッシュされた文脈と可視バッファプレフィックスに注意

技術的革新点

1. 分離設計

  • 静的文脈キャッシュ:一度エンコード、複数回再利用
  • 動的バッファ:段階的更新、目標間依存関係を捉える

2. 訓練カリキュラム

  • 50%の目標は文脈のみに注意
  • 50%の目標は文脈+ランダム長バッファプレフィックスに注意
  • 異なるバッファ状態でモデルが良好に機能することを保証

3. 効率的な推論モード

  • 自回帰サンプリング:文脈を事前入力、目標を系列デコード
  • 同時尤度評価:単一フォワードパスで全条件付き確率を計算
  • バッチサンプリング:文脈キャッシュを共有、独立したバッファ状態

実験設定

データセット

  1. 合成関数
    • ガウス過程(GP):RBF、Matérn-3/2、Matérn-5/2カーネル
    • のこぎり波関数:非ガウス、不連続導関数
  2. EEGデータ:11,520試行、122被験者、7関連チャネル、256時間ポイント
  3. 多感覚因果推論モデル:音視覚定位実験データ、15参加者
  4. 表形式データ:UCIデータセット(電力消費、ガスタービン排出、自転車シェアリング)

評価指標

  • 平均対数尤度:予測品質の評価
  • 実時間:サンプリング、尤度評価、訓練ステップの実際の実行時間
  • 対数周辺尤度RMSE:モデル選択タスクの精度

比較手法

  • TNP-D-Ind:独立予測、高速だが依存性モデリングなし
  • TNP-D-AR:自回帰配置、表現力強いが再エンコーディング必要
  • TNP-ND:多変量ガウス同時分布、表現力限定
  • TNP-A:完全自回帰モデリング、訓練とサンプリング両方遅い

実装詳細

  • 最適化器:Adam、学習率1×10⁻⁴
  • アーキテクチャ:6層Transformer、4注意ヘッド、次元128
  • 予測ヘッド:20成分ガウス混合モデル
  • バッファサイズ:K=16(主要実験)

実験結果

主要結果

計算効率

  • 自回帰サンプリング:TNP-AおよびTNP-D-ARより3~20倍高速
  • 尤度評価:TNP-Aと同等、TNP-D-ARよりK倍高速
  • 訓練速度:TNP-Aより4~12倍高速、最速ベースラインと同等

予測精度

データセットTNP-D-ARTNP-A本手法(K=16)本手法(K=1)
GP2.570.802.512.56
のこぎり波1.05-0.431.001.09
EEG-Int0.510.460.520.54
EEG-For1.07-0.040.851.21

アブレーション実験

  • バッファサイズの影響:K=1時は標準自回帰と等価、K=16時は性能がわずかに低下するが速度が大幅に向上
  • カスタムTritonカーネル:大バッチ時に顕著な加速を提供
  • 注意パターン:FlashAttentionを無効にしてもTNP-Aは他の手法より数桁遅い

ケーススタディ

多感覚因果推論タスクにおいて:

  • モデル選択:LML RMSEは3.56、TNP-D-ARの3.47に近い
  • データ予測:平均対数尤度は-2.76、全強力ベースラインと同等
  • 真値との相関:R²=1.00(LML)、R²=0.92(ΔLML)

関連研究

ニューラルプロセスと先験適合ネットワーク

本手法はモジュール化コンポーネントとして、既存のNP/PFNアーキテクチャに統合できます。文脈セットスケーラビリティに焦点を当てた先行研究を補完し、本論文は自回帰同時サンプリング効率に対応しています。

Transformer確率モデル

ベイズ推論フレームワークを文脈学習タスクとして位置付ける傾向に基づいており、Transformerベースのニューラルプロセスおよび先験適合ネットワーク変種を活用しています。

表形式基礎モデル

TabPFNおよびTabICLなどのモデルと自然に統合され、効率的な同時予測のための補完的なモジュールを提供します。

自回帰同時密度推定

TNP-Aと関連していますが重要な違いがあります:TNP-Aは訓練と推論の両方で目標の繰り返しを使用しますが、本手法は尤度評価時のみ必要です。

結論と議論

主要な結論

  1. 効率の突破口:自回帰Transformerの効率をNP/PFNフレームワークに成功裏に導入
  2. 性能の維持:速度を大幅に向上させながら予測精度を維持
  3. 広範な適用性:複数の領域とタスクで手法の有効性を検証

制限事項

  1. バッファ長スケーリング:Kが増加するとO(K²)項が残り、現在は固定位置埋め込みを使用
  2. 長バッファ品質ドリフト:各ステップ再エンコーディングの正確な自回帰と比較して品質低下の可能性
  3. メモリ占有:文脈キャッシュとバッファ状態の維持が必要

今後の方向性

  1. 位置エンコーディング改善:RoPEまたはALiBiを使用してより長い系列をサポート
  2. 推測デコーディング:ドラフト検証プロセスから借用した適応推論戦略
  3. パラメータ効率的ファインチューニング:アダプタまたはLoRAを使用して事前訓練モデルにバッファ機能を追加

深い評価

強み

  1. 革新性が高い:集合条件付けと自回帰効率のトレードオフを巧妙に解決
  2. 理論が堅実:明確な複雑度分析と数学的導出を提供
  3. 実験が包括的:合成データ、実データ、複数のアプリケーション領域をカバー
  4. 工学的最適化:カスタムCUDAカーネルなど下層の最適化を含む
  5. 再現性:詳細な実装詳細を提供し、コードをオープンソース化予定

不足点

  1. 適用範囲:主に中程度の長さの目標系列に適用可能、超長系列は依然として課題
  2. 理論分析:バッファ近似誤差の理論的界限分析が不足
  3. 比較実験:線形注意など最新の効率的注意メカニズムとの比較がない

影響力

  1. 学術的価値:確率モデルの効率的推論に新しい視点を提供
  2. 実用的価値:同時予測の計算コストを大幅に削減し、実用的なアプリケーションを可能に
  3. スケーラビリティ:手法は汎用性が高く、複数のTransformer変種に適用可能

適用シーン

  • 頻繁な同時サンプリングが必要なアプリケーション(不確実性定量化など)
  • 大規模文脈の系列予測タスク
  • リアルタイム推論要件が高いシーン
  • マルチモーダルデータの同時モデリング

参考文献

主要な参考文献には以下が含まれます:

  • Garnelo et al. (2018): ニューラルプロセス原論文
  • Nguyen & Grover (2022): Transformer Neural Processes
  • Müller et al. (2022): Prior-Fitted Networks
  • Bruinsma et al. (2023): Autoregressive Conditional Neural Processes
  • Jingang et al. (2025): TabICL表形式基礎モデル

総合評価:これは理論的革新、実験検証、工学的実装の全ての面で優れた高品質な研究論文です。本手法は確率モデルにおける重要な効率ボトルネックを成功裏に解決し、広範な応用前景と学術的価値を有しています。