Efficient Autoregressive Inference for Transformer Probabilistic Models
Hassan, Loka, Li et al.
Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtain joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
academic
Inferencia Autorregresiva Eficiente para Modelos Probabilísticos Transformer
Los modelos de inferencia probabilística amortizada basados en Transformer (como Procesos Neurales, Redes Ajustadas Previas y Modelos Fundamentales Tabulares) demuestran un excelente desempeño en predicciones marginales únicas. Sin embargo, muchas aplicaciones prácticas, desde la interpolación de señales hasta la predicción de múltiples columnas tabulares, requieren capturar distribuciones conjuntas coherentes que modelan dependencias entre predicciones. Las arquitecturas puramente autorregresivas pueden generar eficientemente tales distribuciones, pero sacrifican la capacidad flexible de acondicionamiento por conjuntos que hace que estos modelos sean potentes en meta-aprendizaje. Por el contrario, el método estándar para obtener distribuciones conjuntas de modelos basados en conjuntos requiere una recodificación costosa del conjunto de condiciones aumentado completo en cada paso autorregresivo. Este artículo introduce el búfer autorregresivo causal, que retiene las ventajas de ambos paradigmas. El método desacopla la codificación de contexto de las actualizaciones del conjunto de condiciones, permitiendo que el modelo procese el contexto una sola vez y lo almacene en caché, mientras que un búfer dinámico captura las dependencias entre objetivos. En funciones sintéticas, señales EEG, modelos cognitivos y datos tabulares, el método logra aceleraciones de hasta 20 veces en la velocidad de muestreo conjunto mientras mantiene la precisión de predicción comparable a líneas base sólidas.
Los modelos probabilísticos basados en Transformer existentes enfrentan un cuello de botella de eficiencia fundamental: cuando es necesario generar una distribución conjunta, se debe recodificar todo el conjunto de condiciones en cada paso autorregresivo. Específicamente:
Limitaciones de los modelos acondicionados por conjuntos: Los Procesos Neurales (NP), Redes Ajustadas Previas (PFN) y modelos similares sobresalen en predicción marginal, pero requieren recodificación repetida del contexto durante el despliegue autorregresivo, resultando en una complejidad computacional O(K(N+K)²)
Insuficiencias de los modelos puramente autorregresivos: Aunque son computacionalmente eficientes, carecen de capacidad flexible de acondicionamiento por conjuntos, limitando su aplicación en tareas de meta-aprendizaje
Propone el mecanismo de búfer autorregresivo causal: Desacopla la codificación de contexto de acondicionamiento por conjuntos de la predicción secuencial, permitiendo muestreo conjunto eficiente y evaluación de verosimilitud
Diseña una estrategia de entrenamiento unificada: Utiliza enmascaramiento de atención y aprendizaje por currículo de tamaño de búfer, permitiendo que un único modelo aprenda ambos modos de operación con costo adicional mínimo
Verifica aplicabilidad amplia: Logra aceleración de muestreo conjunto de hasta 20 veces en TNP/PFN y Modelos Fundamentales Tabulares, manteniendo precisión de predicción comparable
Optimiza complejidad teórica: Reduce la complejidad computacional de O(K(N+K)²) a O(N²+NK+K²)
Dado un conjunto de contexto C = {(xₙ, yₙ)}ᴺₙ₌₁ y un conjunto de objetivos T = {(xₘ, yₘ)}ᴹₘ₌₁, el objetivo es aprender la distribución de predicción p_θ(y₁:ₘ|x₁:ₘ; C), donde θ son los parámetros del modelo.
Codificador de contexto rC: Procesa pares de contexto, utilizando autoatención multi-cabeza bidireccional, almacenando en caché pares clave-valor en cada capa
Codificador de búfer rB: Aplica autoatención multi-cabeza estrictamente causal al prefijo del búfer
Decodificador de objetivo rtgt: Consulta el contexto almacenado en caché y el prefijo de búfer visible mediante atención cruzada
Impacto del tamaño de búfer: K=1 equivale a autorregresión estándar, K=16 muestra ligera disminución de rendimiento pero aceleración significativa de velocidad
Núcleos Triton personalizados: Proporcionan aceleración significativa en lotes grandes
Patrones de atención: Incluso con FlashAttention deshabilitado, TNP-A sigue siendo órdenes de magnitud más lento que otros métodos
Este método funciona como componente modular que puede integrarse en arquitecturas NP/PFN existentes. Complementa trabajo previo enfocado en escalabilidad de conjuntos de contexto, abordando eficiencia de muestreo conjunto autorregresivo.
Construye sobre la tendencia de enmarcar inferencia bayesiana como tareas de aprendizaje en contexto, aprovechando variantes NP y PFN basadas en Transformer.
Relacionado a TNP-A pero con diferencias clave: TNP-A usa repetición de objetivos tanto en entrenamiento como inferencia, mientras este método solo la requiere en evaluación de verosimilitud.
Bruinsma et al. (2023): Procesos Neurales Condicionales Autorregresivos
Jingang et al. (2025): Modelo Fundamental Tabular TabICL
Evaluación General: Este es un artículo de investigación de alta calidad que demuestra excelencia en innovación teórica, verificación experimental e implementación de ingeniería. El método resuelve exitosamente un cuello de botella de eficiencia importante en modelos probabilísticos, con amplias perspectivas de aplicación y valor académico.