Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Purandare, Idreos et al.
While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof of concept implementation for Hyena, which gets up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
academic
Flash Inference: Inferencia en Tiempo Casi Lineal para Modelos de Secuencia de Convolución Larga y Más Allá
Este artículo aborda el problema de la complejidad temporal cuadrática en la fase de inferencia de modelos de secuencia de convolución larga (LCSMs), proponiendo el marco Flash Inference que reduce la complejidad temporal de la inferencia exacta a casi lineal O(Llog2L). El método se inspira en interpolación polinomial relajada (relaxed polynomial interpolation) y se basa en una estrategia de particionamiento (tiling) para reducir el movimiento de memoria y compartir cálculos. Los experimentos en la arquitectura Hyena demuestran una aceleración de 7.8 veces en inferencia de extremo a extremo y 110 veces en la parte de mezcla de posiciones.
Aunque los Transformers han logrado un éxito enorme en modelos de generación de secuencias, su costo computacional crece cuadráticamente con la longitud de la secuencia (O(L2)), lo que se convierte en un cuello de botella tanto en entrenamiento como en inferencia. Para resolver este problema, los investigadores han propuesto múltiples arquitecturas subquadráticas, incluyendo modelos de espacio de estados (SSMs) y modelos de secuencia de convolución larga (LCSMs, como Hyena).
Eficiencia de Entrenamiento Resuelta: Los LCSMs pueden lograr complejidad O(LlogL) durante el entrenamiento mediante FFT
Eficiencia de Inferencia No Resuelta: Durante la inferencia autorregresiva, como la secuencia de entrada se genera paso a paso, no se puede usar FFT directamente, lo que degrada la complejidad a O(L2)
Demanda de Contexto Largo: Con los modelos de lenguaje grande procesando contextos cada vez más largos, el problema de eficiencia de inferencia se vuelve más prominente
Métodos Aproximados (Massaroli et al. 2024): Proyectan el filtro de convolución a un SSM LTI de baja dimensión, pero esto es solo una aproximación y requiere precálculo de destilación costoso, sin soporte para filtros dependientes de datos
Perspectiva Recursiva: Puede ser eficiente para SSMs de baja dimensión, pero sigue siendo ineficiente para SSMs de alta dimensión (dimensión cercana a la longitud de la secuencia)
Métodos de Explotación de Estructura: Requieren que el filtro tenga estructura específica (como SSM LTI de bajo rango), limitando la capacidad expresiva del modelo
Este artículo tiene como objetivo proporcionar un marco de aceleración de inferencia exacto y universal que no dependa de la estructura específica del filtro, mientras que al mismo tiempo soporta filtros dependientes de datos.
Primer Algoritmo de Inferencia Exacta Casi Lineal: Propone un algoritmo de inferencia exacta con complejidad temporal O(Llog2L) para LCSMs, logrando simulación exacta en comparación con métodos aproximados anteriores
Identificación de Marco Universal: Identifica propiedades arquitectónicas clave que hacen posible la inferencia rápida (base de contribución, independencia de consulta), proponiendo el marco Flash Inference aplicable a una clase más amplia de arquitecturas
Paralelización Entre Capas: Utiliza estrategia de particionamiento para lograr paralelización casi completa de cálculos entre capas en la parte de mezcla de posiciones
Optimización de Memoria: Mediante el método de particionamiento, reduce significativamente el movimiento de datos de Ω(L2) a O(LlogL), ahorrando 2 veces el almacenamiento de activaciones para filtros independientes de datos
Verificación Empírica: Logra aceleración de extremo a extremo de 7.8 veces en la arquitectura Hyena, con 110 veces de aceleración en la parte de convolución
Generación de Secuencia Autorregresiva: Dada una secuencia de indicación x1,…,xp, el modelo necesita generar tokens subsecuentes uno por uno. En cada posición i, el modelo calcula activaciones ai[1,M] a través de todas las capas, finalmente muestreando xi+1 desde aiM.
Cuello de Botella de Cálculo Central: Para cada capa ℓ y cada dimensión, es necesario calcular:
zt=∑i=1tyi⋅ρt−i
donde y es la secuencia de entrada y ρ es un filtro de convolución de longitud L. La implementación ingenua requiere tiempo Ω(L2).
para i = 1 a L-1:
U ← la potencia más grande de 2 que divide i
z_i += y_i * ρ_0 # celda roja: dependencia directa
z[i+1:i+U] += τ(y, [i-U+1, i], ρ, [i+1, i+U]) # bloque gris: cálculo ansioso
devolver z_i
desbloquear y_{i+1}
Características Clave:
En la iteración i, calcula un bloque gris con lado U (donde U es la potencia más grande de 2 que divide i)
La celda roja maneja la dependencia directa de la posición actual
El bloque gris calcula anticipadamente parte de la contribución futura
Análisis de Complejidad (Proposición 1):
Para bloques de longitud 2q, hay 2P−1−q llamadas (donde L=2P)
Tiempo total: ∑q=0P−12P−1−q⋅O(2qlog2q)=O(Llog2L)
Memoria: O(L) (pico determinado por el bloque más grande)
Extiende el Algoritmo 1 a múltiples capas y dimensiones:
para i = 1 a L-1:
U ← la potencia más grande de 2 que divide i
para ℓ = 1 a M: # iterar sobre capas
b^ℓ_i += a^{ℓ-1}_i ⊙ ρ^ℓ_0 # celda roja
a^ℓ_i = block^ℓ(b^ℓ_i)
b^ℓ[i+1:i+U] += τ(a^{ℓ-1}, [i-U+1, i], ρ^ℓ, [i+1, i+U]) # bloque gris
a^0_{i+1} = sampler(a^M_i)
El cálculo de bloques grises puede ejecutarse en paralelo a través de todas las capas:
para i = 1 a L-1:
para ℓ = 1 a M:
procesar celdas rojas (debe ser secuencial)
paralelo para ℓ = 1 a M:
procesar bloques grises (puede ser paralelo)
Ventajas:
Los bloques pequeños (87.5% de bloques con tamaño ≤4) típicamente están limitados por latencia de memoria, la paralelización puede saturar el ancho de banda de memoria
Los bloques grandes usan FFT, son intensivos en cálculo, la paralelización mejora el rendimiento
P.1 Basado en Contribución (Contribution-based):
El Mixer funciona mediante agregación de contribuciones:
mixer(y)i=read(agg(cont(y,1,i),…,cont(y,i,i)))
Supongamos que existe un algoritmo A que puede calcular contribuciones de bloque en tiempo T(L1,L2):
A(y,[l,r],[l′,r′])=agg(cont(y,l,p),…,cont(y,r,p))
Teorema 2: Bajo P.1 y P.2, cada capa ejecuta:
L−1 llamadas a A (llamadas 2P−1−q de longitud 2q)
Tiempo total: ∑q=0P−12P−1−qT(2q,2q)
Paralelización entre capas: bloques grises sin dependencia de datos, pueden ser paralelos
CUDA Graphs: Registrar todas las invocaciones de núcleo para generación de un token como gráfico, reproducir posteriormente para reducir sobrecarga de CPU (mejora 10-20%)
Precálculo de FFT: Precalcular DFT del núcleo de convolución para log2(L)−1 tamaños de bloque
Preconfiguración de FlashFFT: Preinicializar configuración para diferentes tamaños de bloque para maximizar rendimiento de hardware
Relleno Derecho: Usar relleno derecho en lugar de izquierdo, reduciendo tiempo de cálculo a la mitad
Convolución Circular: Aprovechar propiedades de convolución circular para reducir longitud de FFT a la mitad
Consistencia Teoría-Práctica: Complejidad O(Llog2L) se refleja en aceleración significativa en experimentos
Importancia del Ancho de Banda de Memoria: Flash Conv1D aunque es complejidad cuadrática, aún logra 5× de aceleración mediante optimización de acceso a memoria
Necesidad de Selección Dinámica: Ninguna implementación única de τ es óptima para todos los tamaños de bloque, estrategia Hybrid es crítica
Sobrecarga de CPU No Despreciable: CUDA Graphs eleva aceleración de extremo a extremo de 1.6× a 8×
Beneficio de Paralelización: Bloques pequeños dominan (87.5%), paralelización entre capas muy efectiva
Contribución Teórica: Primer algoritmo de inferencia exacta O(Llog2L) para LCSMs
Marco Universal: Identificación de propiedades clave (basado en contribución, independencia de consulta), aplicable a arquitecturas más amplias
Verificación Empírica: Aceleración de extremo a extremo 7.8× en Hyena, 110× en parte mixer
Optimización de Sistema: Paralelización entre capas, optimización de memoria, selección dinámica de implementación y otras contribuciones de ingeniería
Diseño de Arquitectura: Diseñar nuevas arquitecturas que satisfagan requisitos de Flash Inference con alta calidad
Filtros Dependientes de Datos Causales: Cómo hacer filtros dependientes de datos mientras se mantiene causalidad (Arora et al., Karami & Ghodsi ya muestran potencial)
Métodos Híbridos: Combinar perspectiva recursiva (dimensión de estado pequeña) y perspectiva de convolución (dimensión de estado grande)
Más Arquitecturas: Extender a otros modelos que satisfagan propiedades del marco (como ciertas variantes de atención)
Inferencia Distribuida: Optimización en escenarios multi-GPU/multi-nodo
Análisis de Complejidad Completo: Desde Lema 1 hasta Teorema 2, cadena de prueba clara
Abstracción de Marco Universal: Propiedades P.1 y P.2 abstraídas apropiadamente, incluyendo LCSMs pero excluyendo casos inaplicables (como Transformer)
Selección de Herramientas Matemáticas: Aplicación ingeniosa de teoría de interpolación polinomial relajada
van der Hoeven, J. (1997). Lazy multiplication of formal power series. ISSAC. Fundamento Teórico
Poli, M. et al. (2023). Hyena hierarchy: Towards larger convolutional language models. Objeto Principal de Aplicación
Massaroli, S. et al. (2024). Laughing hyena distillery: Extracting compact recurrences from convolutions. NeurIPS. Comparación de Método Aproximado
Gu, A. & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. Trabajo Relacionado SSM
Fu, D. Y. et al. (2023). FlashFFTConv: Efficient convolutions for long sequences with tensor cores. Base de Implementación
Agarwal, N. et al. (2024). FutureFill: Fast generation from convolutional sequence models. Trabajo Paralelo
Evaluación General: Este es un artículo excelente que combina estrechamente teoría y práctica. Teóricamente, proporciona el primer algoritmo de inferencia exacta casi lineal para LCSMs e identifica un marco universal; prácticamente, logra aceleración significativa mediante optimización a nivel de sistema. Las limitaciones principales son que LCSMs en sí no son tan populares como Transformer en aplicaciones reales, y la verificación experimental de filtros dependientes de datos es insuficiente. Este trabajo proporciona una nueva perspectiva para optimización de inferencia de modelos de secuencia, particularmente valiosa para diseño de arquitectura futura. Recomendado para investigadores interesados en eficiencia de modelos, modelado de secuencias y optimización de sistemas.