Efficient Autoregressive Inference for Transformer Probabilistic Models
Hassan, Loka, Li et al.
Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtain joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
Transformer 기반의 상각화된 확률 추론 모델(신경 과정, 사전 적합 네트워크, 표 기초 모델 등)은 단일 주변 예측에서 우수한 성능을 보입니다. 그러나 신호 보간에서 다중 열 표 예측에 이르기까지 많은 실제 응용 프로그램은 예측 간 종속성을 포착하는 일관된 결합 분포가 필요합니다. 순수 자회귀 아키텍처는 이러한 분포를 효율적으로 생성할 수 있지만, 이러한 모델을 메타 학습에서 강력하게 만드는 유연한 집합 조건화 능력을 포기합니다. 반대로, 집합 기반 모델에서 결합 분포를 얻는 표준 방법은 각 자회귀 단계에서 전체 증강 조건 집합을 비용이 많이 드는 재인코딩이 필요합니다. 본 논문은 인과 자회귀 버퍼를 도입하여 두 패러다임의 장점을 보존합니다. 이 방법은 컨텍스트 인코딩을 조건 집합 업데이트와 분리하여, 모델이 컨텍스트를 한 번 처리하고 캐시하며, 동적 버퍼가 목표 종속성을 포착합니다. 합성 함수, EEG 신호, 인지 모델 및 표 데이터에서 이 방법은 강력한 기준선 예측 정확도와 일치하면서 결합 샘플링 속도를 최대 20배 향상시킵니다.