Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Purandare, Idreos et al.
While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof of concept implementation for Hyena, which gets up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
academic
Flash Inference: Inferenza in Tempo Quasi-Lineare per Modelli di Sequenza a Convoluzione Lunga e Oltre
Questo articolo affronta il problema della complessità temporale quadratica durante l'inferenza nei modelli di sequenza a convoluzione lunga (LCSMs), proponendo il framework Flash Inference che riduce la complessità temporale dell'inferenza esatta a quasi-lineare O(Llog2L). Il metodo è ispirato dall'interpolazione polinomiale rilassata e si basa su strategie di tiling per ridurre il movimento in memoria e condividere i calcoli. Gli esperimenti sull'architettura Hyena dimostrano un'accelerazione end-to-end di 7,8 volte e un'accelerazione di 110 volte per la parte di mixing posizionale.
Sebbene i Transformer abbiano ottenuto enorme successo nei modelli di generazione di sequenze, il loro costo computazionale cresce quadraticamente con la lunghezza della sequenza (O(L2)), diventando un collo di bottiglia sia durante l'addestramento che l'inferenza. Per risolvere questo problema, i ricercatori hanno proposto diverse architetture sub-quadratiche, inclusi i modelli di spazio degli stati (SSMs) e i modelli di sequenza a convoluzione lunga (LCSMs, come Hyena).
Efficienza di Addestramento Risolta: gli LCSMs possono raggiungere complessità O(LlogL) durante l'addestramento tramite FFT
Efficienza di Inferenza Non Risolta: durante l'inferenza autoregressiva, poiché la sequenza di input viene generata progressivamente, non è possibile utilizzare direttamente l'FFT, causando una degradazione della complessità a O(L2)
Necessità di Contesti Lunghi: con i modelli di linguaggio di grandi dimensioni che elaborano contesti sempre più lunghi, il problema dell'efficienza di inferenza diventa sempre più critico
Metodi Approssimativi (Massaroli et al. 2024): proiettano il filtro di convoluzione in un SSM LTI a bassa dimensione, ma si tratta solo di un'approssimazione e richiede una precomputazione di distillazione costosa, non supporta filtri dipendenti dai dati
Prospettiva Ricorsiva: potrebbe essere efficiente per SSM a bassa dimensione, ma rimane inefficiente per SSM ad alta dimensione (dimensione prossima alla lunghezza della sequenza)
Metodi di Sfruttamento della Struttura: richiedono che il filtro abbia una struttura specifica (come SSM LTI a basso rango), limitando la capacità espressiva del modello
Questo articolo mira a fornire un framework di accelerazione dell'inferenza esatto e universale che non dipenda dalla struttura specifica del filtro, supportando al contempo filtri dipendenti dai dati.
Primo Algoritmo di Inferenza Esatto Quasi-Lineare: propone un algoritmo di inferenza esatto con complessità temporale O(Llog2L) per gli LCSMs, realizzando una simulazione esatta rispetto ai metodi approssimativi precedenti
Identificazione di Framework Universale: identifica le proprietà architetturali chiave che rendono possibile l'inferenza veloce (base di contribuzione, indipendenza dalla query), proponendo il framework Flash Inference applicabile a una classe più ampia di architetture
Parallelizzazione Cross-Layer: sfrutta strategie di tiling per realizzare il calcolo quasi completamente parallelo cross-layer della parte di mixing posizionale
Ottimizzazione della Memoria: attraverso il metodo di tiling riduce significativamente il movimento dei dati da Ω(L2) a O(LlogL), risparmiando 2 volte l'archiviazione di attivazione per filtri indipendenti dai dati
Verifica Empirica: realizza un'accelerazione end-to-end di 7,8 volte sull'architettura Hyena, con 110 volte di accelerazione per la parte di convoluzione
Generazione di Sequenza Autoregressiva: data una sequenza di prompt x1,…,xp, il modello deve generare i token successivi uno per uno. Ad ogni posizione i, il modello calcola le attivazioni ai[1,M] attraverso tutti gli strati, infine campionando xi+1 da aiM.
Collo di Bottiglia Computazionale: per ogni strato ℓ e ogni dimensione, è necessario calcolare:
zt=∑i=1tyi⋅ρt−i
dove y è la sequenza di input e ρ è il filtro di convoluzione di lunghezza L. L'implementazione ingenua richiede tempo Ω(L2).
for i = 1 to L-1:
U ← massima potenza di 2 che divide i
for ℓ = 1 to M: # iterazione su strati
b^ℓ_i += a^{ℓ-1}_i ⊙ ρ^ℓ_0 # cella rossa
a^ℓ_i = block^ℓ(b^ℓ_i)
b^ℓ[i+1:i+U] += τ(a^{ℓ-1}, [i-U+1, i], ρ^ℓ, [i+1, i+U]) # blocco grigio
a^0_{i+1} = sampler(a^M_i)
Complessità (Proposizione 2):
Parte Mixer: O(MDLlog2L)
Parte Block: LM invocazioni (generalmente O(MLD2))
Il calcolo dei blocchi grigi può essere eseguito in parallelo su tutti gli strati:
for i = 1 to L-1:
for ℓ = 1 to M:
elaborazione celle rosse (deve essere sequenziale)
parallel for ℓ = 1 to M:
elaborazione blocchi grigi (può essere parallela)
Vantaggi:
I blocchi piccoli (87,5% dei blocchi hanno dimensione ≤4) sono generalmente limitati dalla latenza della memoria, la parallelizzazione può saturare la larghezza di banda della memoria
I blocchi grandi utilizzano FFT, sono computazionalmente intensivi, la parallelizzazione migliora il throughput
La convoluzione FFT standard richiede FFT di lunghezza 4U (lunghezza di output 3U-1)
Questo articolo richiede solo convoluzione circolare di lunghezza 2U (l'intervallo di output di interesse [U,2U−1] non è influenzato dalla circolarità)
P.1 Base di Contribuzione (Contribution-based):
Il Mixer funziona attraverso aggregazione di contribuzioni:
mixer(y)i=read(agg(cont(y,1,i),…,cont(y,i,i)))
Assumendo l'esistenza di un algoritmo A che possa calcolare il contributo di blocco in tempo T(L1,L2):
A(y,[l,r],[l′,r′])=agg(cont(y,l,p),…,cont(y,r,p))
Teorema 2: sotto P.1 e P.2, ogni strato esegue:
L−1 invocazioni di A (2P−1−q invocazioni di lunghezza 2q)
Tempo totale: ∑q=0P−12P−1−qT(2q,2q)
Parallelizzazione cross-layer: i blocchi grigi non hanno dipendenze di dati, possono essere parallelizzati
CUDA Graphs: registra la pianificazione di tutti i kernel per la generazione di un singolo token come grafico, successivamente riproduce per ridurre l'overhead della CPU (miglioramento 10-20%)
Precomputazione FFT: precomputa la DFT del kernel di convoluzione per log2(L)−1 diverse dimensioni di blocco
Preconfigurazione FlashFFT: preinizia le configurazioni per diverse dimensioni di blocco per massimizzare le prestazioni hardware
Padding Destro: utilizza padding destro anziché sinistro, riducendo il tempo di calcolo della metà
Convoluzione Circolare: sfrutta la proprietà di convoluzione circolare per ridurre la lunghezza FFT della metà
Coerenza Teoria-Pratica: la complessità O(Llog2L) si manifesta come accelerazione significativa negli esperimenti
Importanza della Larghezza di Banda della Memoria: Flash Conv1D, sebbene quadratico, ottiene 5 volte di accelerazione attraverso l'ottimizzazione dell'accesso alla memoria
Necessità della Selezione Dinamica: nessuna singola implementazione di τ è ottimale per tutte le dimensioni di blocco, la strategia Hybrid è cruciale
Overhead della CPU Non Trascurabile: CUDA Graphs migliora l'accelerazione end-to-end da 1,6× a 8×
Benefici della Parallelizzazione: i blocchi piccoli dominano (87,5%), la parallelizzazione cross-layer è significativa
Filtri Dipendenti dai Dati: sebbene teoricamente supportati, richiedono 2 volte il calcolo, non completamente verificati negli esperimenti
Requisiti di Memoria: ancora necessario memorizzare tutte le attivazioni O(MLD) (vs prospettiva ricorsiva di O(MD′))
Ambito di Applicabilità:
Non applicabile a Transformer (non soddisfa indipendenza dalla query)
Per SSM a dimensione molto bassa (D′≪log2L), la prospettiva ricorsiva potrebbe essere più ottimale
Fase di Prompt: con prompt lunghi, il prefill (precompilazione) domina ancora il tempo, il beneficio relativo dell'ottimizzazione autoregressiva è limitato
Dipendenza dall'Hardware: l'effetto di accelerazione dipende dalle caratteristiche della larghezza di banda della memoria GPU
Progettazione di Architetture: progettare nuove architetture che soddisfino i requisiti di Flash Inference e mantengano alta qualità
Filtri Dipendenti dai Dati Causali: come rendere il filtro dipendente dai dati mantenendo la causalità (Arora et al., Karami & Ghodsi hanno mostrato potenziale)
Metodi Ibridi: combinare la prospettiva ricorsiva (dimensione dello stato piccola) e la prospettiva di convoluzione (dimensione dello stato grande)
Più Architetture: estendere a altre architetture che soddisfano le proprietà del framework (come alcune varianti di attenzione)
Inferenza Distribuita: ottimizzazioni per scenari multi-GPU/multi-nodo
Analisi di Complessità Completa: dalla Lemma 1 al Teorema 2, la catena di prove è chiara
Astrazione di Framework Universale: le proprietà P.1 e P.2 sono astratte appropriatamente, includono gli LCSMs ed escludono i casi non applicabili (come i Transformer)
Scelta di Strumenti Matematici: l'applicazione della teoria dell'interpolazione polinomiale rilassata è ingegnosa
Strategia di Tiling: il tiling rettangolare bilanciato (vs strisce sottili) è un'intuizione chiave
Parallelizzazione Cross-Layer: l'identificazione che i blocchi grigi non hanno dipendenze rompe il limite dell'esecuzione sequenziale tradizionale per strati
Selezione Dinamica di Implementazioni: la strategia Hybrid riflette una profonda comprensione delle caratteristiche hardware
Con Metodi Approssimativi: nessun confronto sperimentale con il compromesso qualità-velocità di Massaroli et al.
Con Prospettiva Ricorsiva: analisi quantitativa insufficiente su quando la prospettiva ricorsiva è più ottimale (solo discussione qualitativa di D′∈O(log2L))
Con Sfruttamento della Struttura: nessun confronto con metodi di struttura specifica come convoluzione dilatata
van der Hoeven, J. (1997). Lazy multiplication of formal power series. ISSAC. Fondamenti Teorici
Poli, M. et al. (2023). Hyena hierarchy: Towards larger convolutional language models. Oggetto di Applicazione Principale
Massaroli, S. et al. (2024). Laughing hyena distillery: Extracting compact recurrences from convolutions. NeurIPS. Confronto Metodi Approssimativi
Gu, A. & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. Lavori Correlati SSM
Fu, D. Y. et al. (2023). FlashFFTConv: Efficient convolutions for long sequences with tensor cores. Fondamenti di Implementazione
Agarwal, N. et al. (2024). FutureFill: Fast generation from convolutional sequence models. Lavori Paralleli
Valutazione Complessiva: questo è un articolo eccellente che combina strettamente teoria e pratica. Dal punto di vista teorico, fornisce il primo algoritmo di inferenza esatto quasi-lineare per gli LCSMs e astrae un framework universale; dal punto di vista pratico, realizza accelerazioni significative attraverso ottimizzazioni a livello di sistema. Le limitazioni principali risiedono nel fatto che gli LCSMs stessi non sono così diffusi nelle applicazioni pratiche come i Transformer, e la verifica sperimentale dei filtri dipendenti dai dati è insufficiente. Questo lavoro fornisce una nuova prospettiva per l'ottimizzazione dell'inferenza di modelli di sequenza, in particolare con significato guida per la progettazione di architetture future. Consigliato ai ricercatori interessati all'efficienza dei modelli, alla modellazione di sequenze e all'ottimizzazione di sistema.