2025-11-25T22:34:18.624435

Efficient Autoregressive Inference for Transformer Probabilistic Models

Hassan, Loka, Li et al.
Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtain joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
academic

Transformer 확률 모델을 위한 효율적인 자회귀 추론

기본 정보

  • 논문 ID: 2510.09477
  • 제목: Efficient Autoregressive Inference for Transformer Probabilistic Models
  • 저자: Conor Hassan, Nasrulloh Loka, Cen-You Li, Daolang Huang, Paul E. Chang, Yang Yang, Francesco Silvestrin, Samuel Kaski, Luigi Acerbi
  • 분류: stat.ML cs.LG
  • 발표 시간: 2025년 10월 10일 (arXiv 사전인쇄본)
  • 논문 링크: https://arxiv.org/abs/2510.09477

초록

Transformer 기반의 상각화된 확률 추론 모델(신경 과정, 사전 적합 네트워크, 표 기초 모델 등)은 단일 주변 예측에서 우수한 성능을 보입니다. 그러나 신호 보간에서 다중 열 표 예측에 이르기까지 많은 실제 응용 프로그램은 예측 간 종속성을 포착하는 일관된 결합 분포가 필요합니다. 순수 자회귀 아키텍처는 이러한 분포를 효율적으로 생성할 수 있지만, 이러한 모델을 메타 학습에서 강력하게 만드는 유연한 집합 조건화 능력을 포기합니다. 반대로, 집합 기반 모델에서 결합 분포를 얻는 표준 방법은 각 자회귀 단계에서 전체 증강 조건 집합을 비용이 많이 드는 재인코딩이 필요합니다. 본 논문은 인과 자회귀 버퍼를 도입하여 두 패러다임의 장점을 보존합니다. 이 방법은 컨텍스트 인코딩을 조건 집합 업데이트와 분리하여, 모델이 컨텍스트를 한 번 처리하고 캐시하며, 동적 버퍼가 목표 종속성을 포착합니다. 합성 함수, EEG 신호, 인지 모델 및 표 데이터에서 이 방법은 강력한 기준선 예측 정확도와 일치하면서 결합 샘플링 속도를 최대 20배 향상시킵니다.

연구 배경 및 동기

핵심 문제

기존 Transformer 기반 확률 모델은 근본적인 효율성 병목에 직면합니다: 결합 분포를 생성해야 할 때 각 자회귀 단계에서 전체 조건 집합을 재인코딩해야 합니다. 구체적으로:

  1. 집합 조건화 모델의 한계: 신경 과정(NPs), 사전 적합 네트워크(PFNs) 등의 모델은 주변 예측에 능숙하지만 자회귀 배포 시 컨텍스트를 반복적으로 재인코딩해야 하므로 O(K(N+K)²)의 계산 복잡도가 발생합니다.
  2. 순수 자회귀 모델의 부족: 계산 효율적이지만 유연한 집합 조건화 능력이 부족하여 메타 학습 작업에서의 응용이 제한됩니다.

중요성

결합 분포 예측은 여러 중요한 응용에서 필수적입니다:

  • 신호 보간의 시간 종속성
  • 다중 열 표 예측의 특성 상관성
  • 행동 데이터 모델링의 순차 종속성
  • 베이지안 모델 선택의 결합 우도 평가

기존 방법의 한계

  1. TNP-D 자회귀 배포: 각 단계에서 증가하는 조건 집합을 재인코딩해야 함
  2. TNP-A: 훈련과 추론 모두 반복된 목표 집합을 처리해야 하므로 계산 오버헤드가 큼
  3. TNP-ND: 다변량 가우스 분포로만 제한되어 표현 능력이 제한됨

핵심 기여

  1. 인과 자회귀 버퍼 메커니즘 제안: 집합 조건화의 컨텍스트 인코딩을 순차 예측과 분리하여 효율적인 결합 샘플링 및 우도 평가 구현
  2. 통합 훈련 전략 설계: 마스크된 주의와 버퍼 크기 커리큘럼 학습을 사용하여 단일 모델이 최소 추가 비용으로 두 가지 작동 모드를 학습하도록 함
  3. 광범위한 적용 가능성 검증: TNPs/PFNs 및 표 기초 모델에서 결합 샘플링 가속을 최대 20배 달성하면서 비교 가능한 예측 정확도 유지
  4. 이론적 복잡도 최적화: 계산 복잡도를 O(K(N+K)²)에서 O(N²+NK+K²)로 감소

방법 상세 설명

작업 정의

컨텍스트 집합 C = {(xₙ, yₙ)}ᴺₙ₌₁과 목표 집합 T = {(xₘ, yₘ)}ᴹₘ₌₁이 주어졌을 때, 목표는 예측 분포 p_θ(y₁:ₘ|x₁:ₘ; C)를 학습하는 것입니다. 여기서 θ는 모델 매개변수입니다.

모델 아키텍처

핵심 구성 요소

  1. 컨텍스트 인코더 rC: 컨텍스트 쌍을 처리하며, 양방향 다중 헤드 자기 주의를 사용하고 각 계층의 키-값 쌍을 캐시합니다.
  2. 버퍼 인코더 rB: 버퍼 접두사에 엄격한 인과 다중 헤드 자기 주의를 사용합니다.
  3. 목표 디코더 rtgt: 교차 주의를 통해 캐시된 컨텍스트 및 가시 버퍼 접두사를 쿼리합니다.

예측 분포 매개변수화

p_θ(y*₁:K|x*₁:K; C) = ∏ᴷₖ₌₁ p_θ(y*ₖ|rtgt(x*ₖ, [rC(C), b₁:ₖ₋₁]))

여기서 bₖ = rB((xₖ, yₖ), rC(C), b₁:ₖ₋₁)

주의 마스크 설계

네 가지 핵심 요구 사항 구현:

  • (R1) 컨텍스트 불변성: 한 번 인코딩되고 읽기 전용으로 캐시됨
  • (R2) 버퍼 엄격한 인과성: 토큰 j는 <j 위치만 관심 가능
  • (R3) 정보 단방향 흐름: C로의 역방향 쓰기 없음
  • (R4) 목표 관심: 캐시된 컨텍스트 및 가시 버퍼 접두사에 관심

기술 혁신 포인트

1. 분리 설계

  • 정적 컨텍스트 캐시: 한 번 인코딩, 여러 번 재사용
  • 동적 버퍼: 증분 업데이트, 목표 간 종속성 포착

2. 훈련 커리큘럼

  • 50% 목표는 컨텍스트만 관심
  • 50% 목표는 컨텍스트 + 무작위 길이 버퍼 접두사에 관심
  • 다양한 버퍼 상태에서 모델이 잘 작동하도록 보장

3. 효율적인 추론 모드

  • 자회귀 샘플링: 컨텍스트 사전 채우기, 순차 목표 디코딩
  • 결합 우도 평가: 단일 전방 전달로 모든 조건부 확률 계산
  • 배치 샘플링: 공유 컨텍스트 캐시, 독립적 버퍼 상태

실험 설정

데이터 집합

  1. 합성 함수:
    • 가우스 과정(GP): RBF, Matérn-3/2, Matérn-5/2 커널
    • 톱니 함수: 비가우스, 불연속 도함수
  2. EEG 데이터: 11,520개 시행, 122명 피험자, 7개 관련 채널, 256개 시간점
  3. 다중감각 인과 추론 모델: 음시각 위치 결정 실험 데이터, 15명 참여자
  4. 표 데이터: UCI 데이터 집합(전력 소비, 가스 터빈 배출, 자전거 공유)

평가 지표

  • 평균 로그 우도: 예측 품질 평가
  • 벽시계 시간: 샘플링, 우도 평가, 훈련 단계의 실제 실행 시간
  • 로그 주변 우도 RMSE: 모델 선택 작업의 정확도

비교 방법

  • TNP-D-Ind: 독립 예측, 빠르지만 종속성 모델링 없음
  • TNP-D-AR: 자회귀 배포, 표현력 강하지만 재인코딩 필요
  • TNP-ND: 다변량 가우스 결합 분포, 표현력 제한
  • TNP-A: 완전 자회귀 모델링, 훈련 및 샘플링 모두 느림

구현 세부 사항

  • 최적화기: Adam, 학습률 1×10⁻⁴
  • 아키텍처: 6층 Transformer, 4개 주의 헤드, 차원 128
  • 예측 헤드: 20 성분 가우스 혼합 모델
  • 버퍼 크기: K=16(주요 실험)

실험 결과

주요 결과

계산 효율성

  • 자회귀 샘플링: TNP-A 및 TNP-D-AR보다 3-20배 빠름
  • 우도 평가: TNP-A와 비슷하며 TNP-D-AR보다 K배 빠름
  • 훈련 속도: TNP-A보다 4-12배 빠르며 가장 빠른 기준선과 비슷함

예측 정확도

데이터 집합TNP-D-ARTNP-A본 방법(K=16)본 방법(K=1)
GP2.570.802.512.56
Sawtooth1.05-0.431.001.09
EEG-Int0.510.460.520.54
EEG-For1.07-0.040.851.21

제거 실험

  • 버퍼 크기 영향: K=1일 때 표준 자회귀와 동등하며, K=16일 때 약간의 성능 저하이지만 속도 대폭 향상
  • 사용자 정의 Triton 커널: 대규모 배치에서 상당한 가속 제공
  • 주의 패턴: FlashAttention을 비활성화해도 TNP-A는 여전히 다른 방법보다 수 배 느림

사례 분석

다중감각 인과 추론 작업에서:

  • 모델 선택: LML RMSE 3.56으로 TNP-D-AR의 3.47에 근접
  • 데이터 예측: 평균 로그 우도 -2.76으로 모든 강력한 기준선과 비슷함
  • 실제 값과의 상관성: R²=1.00(LML), R²=0.92(ΔLML)

관련 연구

신경 과정 및 사전 적합 네트워크

본 방법은 모듈식 구성 요소로 기존 NP/PFN 아키텍처에 통합될 수 있습니다. 컨텍스트 집합 확장성에 초점을 맞춘 이전 작업을 보완하면서, 본 논문은 자회귀 결합 샘플링 효율성을 다룹니다.

Transformer 확률 모델

베이지안 추론 프레임워크를 컨텍스트 학습 작업으로 구성하는 추세를 바탕으로 구축되며, Transformer 기반 NP 및 PFN 변형을 활용합니다.

표 기초 모델

TabPFN 및 TabICL과 같은 모델과 자연스럽게 통합되어 효율적인 결합 예측을 위한 보완 모듈을 제공합니다.

자회귀 결합 밀도 추정

TNP-A와 관련이 있지만 핵심 차이점이 있습니다: TNP-A는 훈련 및 추론 모두에서 목표 반복을 사용하는 반면, 본 방법은 우도 평가 시에만 필요합니다.

결론 및 논의

주요 결론

  1. 효율성 돌파: 자회귀 Transformer의 효율성을 NP/PFN 프레임워크에 성공적으로 도입
  2. 성능 유지: 속도를 대폭 향상시키면서 예측 정확도 유지
  3. 광범위한 적용 가능성: 여러 영역 및 작업에서 방법의 효과성 검증

한계

  1. 버퍼 길이 확장: K가 증가할 때 여전히 O(K²) 항이 있으며, 현재 고정 위치 임베딩 사용
  2. 긴 버퍼 품질 드리프트: 각 단계 재인코딩의 정확한 자회귀와 비교하여 품질 저하 가능성
  3. 메모리 점유: 컨텍스트 캐시 및 버퍼 상태 유지 필요

향후 방향

  1. 위치 인코딩 개선: RoPE 또는 ALiBi를 사용하여 더 긴 시퀀스 지원
  2. 추측 디코딩: 초안-검증 프로세스의 적응형 추론 전략 차용
  3. 매개변수 효율적 미세 조정: 어댑터 또는 LoRA를 사용하여 사전 훈련된 모델에 버퍼 기능 추가

심층 평가

장점

  1. 높은 혁신성: 집합 조건화와 자회귀 효율성의 균형을 교묘하게 해결
  2. 견고한 이론: 명확한 복잡도 분석 및 수학적 유도 제공
  3. 포괄적인 실험: 합성 데이터, 실제 데이터, 여러 응용 영역 포함
  4. 공학 최적화: 사용자 정의 CUDA 커널 등 저수준 최적화 포함
  5. 재현 가능성: 상세한 구현 세부 사항 및 오픈 소스 코드 제공 예정

부족한 점

  1. 적용 범위: 주로 중간 길이의 목표 시퀀스에 적용 가능하며, 초장 시퀀스는 여전히 도전 과제
  2. 이론적 분석: 버퍼 근사 오류의 이론적 경계 분석 부족
  3. 비교 실험: 선형 주의 등 최신 효율적 주의 메커니즘과의 비교 부족

영향력

  1. 학술적 가치: 확률 모델의 효율적 추론에 새로운 관점 제공
  2. 실용적 가치: 결합 예측의 계산 비용을 크게 감소시켜 실제 응용 가능
  3. 확장성: 다양한 Transformer 변형에 적용 가능한 우수한 범용성

적용 시나리오

  • 빈번한 결합 샘플링이 필요한 응용(예: 불확실성 정량화)
  • 대규모 컨텍스트의 순차 예측 작업
  • 실시간 추론 요구 사항이 높은 시나리오
  • 다중 모달 데이터의 결합 모델링

참고 문헌

주요 참고 문헌 포함:

  • Garnelo et al. (2018): 신경 과정 원본 논문
  • Nguyen & Grover (2022): Transformer 신경 과정
  • Müller et al. (2022): 사전 적합 네트워크
  • Bruinsma et al. (2023): 자회귀 조건부 신경 과정
  • Jingang et al. (2025): TabICL 표 기초 모델

전체 평가: 이것은 이론적 혁신, 실험 검증 및 공학 구현 측면에서 모두 우수한 고품질 연구 논문입니다. 이 방법은 확률 모델의 중요한 효율성 병목을 성공적으로 해결하며 광범위한 응용 전망과 학술적 가치를 가집니다.