Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Purandare, Idreos et al.
While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof of concept implementation for Hyena, which gets up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
academic
Flash Inference: Nahezu lineare Inferenzzeit für lange Faltungs-Sequenzmodelle und darüber hinaus
Dieses Paper adressiert das Problem der quadratischen Zeitkomplexität bei der Inferenz von langen Faltungs-Sequenzmodellen (LCSMs) und schlägt das Flash Inference-Framework vor, das die Zeitkomplexität für exakte Inferenz auf quasi-linear O(Llog2L) reduziert. Die Methode wird durch relaxierte Polynominterpolation inspiriert und basiert auf einer Kachelung-Strategie (Tiling), um Speicherbewegungen zu reduzieren und Berechnungen zu teilen. Experimente auf der Hyena-Architektur zeigen eine End-to-End-Beschleunigung von 7,8× und eine Beschleunigung von 110× für den Positionsmischungs-Teil.
Obwohl Transformer in Sequenzgenerierungsmodellen großen Erfolg hatten, wächst ihr Rechenaufwand quadratisch mit der Sequenzlänge (O(L2)), was sowohl in der Trainings- als auch in der Inferenzphase ein Engpass ist. Um dieses Problem zu lösen, haben Forscher verschiedene subquadratische Architekturen vorgeschlagen, einschließlich State-Space-Modelle (SSMs) und lange Faltungs-Sequenzmodelle (LCSMs, wie Hyena).
Trainingseffizienz gelöst: LCSMs können durch FFT während des Trainings eine Komplexität von O(LlogL) erreichen
Inferenzeffizienz ungelöst: Bei autoregressiver Inferenz kann FFT nicht direkt verwendet werden, da die Eingabesequenz schrittweise generiert wird, was die Komplexität auf O(L2) verschlechtert
Anforderung für lange Kontexte: Mit großen Sprachmodellen, die immer längere Kontexte verarbeiten, wird das Inferenzeffizienzproblem immer kritischer
Approximationsmethoden (Massaroli et al. 2024): Projizieren Faltungsfilter auf niedrigdimensionale LTI-SSMs, aber dies ist nur eine Approximation und erfordert teure Destillations-Vorberechnung, unterstützt keine datenabhängigen Filter
Rekursive Perspektive: Kann für niedrigdimensionale SSMs effizient sein, ist aber für hochdimensionale SSMs (Dimension nahe der Sequenzlänge) immer noch ineffizient
Strukturnutzungsmethoden: Erfordern, dass Filter eine bestimmte Struktur haben (z.B. niedrig-rangige LTI-SSMs), was die Modellausdruckskraft einschränkt
Dieses Paper zielt darauf ab, ein exaktes und universelles Inferenzbeschleunigungsframework bereitzustellen, das nicht von der spezifischen Struktur des Filters abhängt und gleichzeitig datenabhängige Filter unterstützt.
Erster quasi-linearer exakter Inferenz-Algorithmus: Schlägt einen Inferenz-Algorithmus mit O(Llog2L) Zeitkomplexität für LCSMs vor, der im Gegensatz zu früheren Approximationsmethoden exakte Simulation erreicht
Universelle Framework-Identifikation: Identifiziert Schlüssel-Architektur-Eigenschaften, die schnelle Inferenz ermöglichen (Beitrag-basiert, Abfrage-unabhängig), und schlägt das Flash Inference-Framework vor, das auf eine breitere Architektur-Klasse anwendbar ist
Schicht-übergreifende Parallelisierung: Nutzt Kachelung-Strategie, um nahezu vollständige schicht-übergreifende parallele Berechnung des Positionsmischungs-Teils zu ermöglichen
Speicheroptimierung: Reduziert Datenbewegung durch Kachelung-Methoden signifikant von Ω(L2) auf O(LlogL), spart 2× Aktivierungsspeicher für datenunabhängige Filter
Empirische Validierung: Erreicht 7,8× End-to-End-Beschleunigung auf Hyena-Architektur, 110× Beschleunigung für Faltungsteil
Autoregressive Sequenzgenerierung: Gegeben eine Eingabesequenz x1,…,xp, muss das Modell nachfolgende Token schrittweise generieren. Bei jeder Position i berechnet das Modell Aktivierungen ai[1,M] durch alle Schichten, und sampelt schließlich xi+1 aus aiM.
Kern-Berechnungsengpass: Für jede Schicht ℓ und jede Dimension muss berechnet werden:
zt=∑i=1tyi⋅ρt−i
wobei y die Eingabesequenz ist und ρ ein Faltungsfilter der Länge L ist. Die naive Implementierung benötigt Ω(L2) Zeit.
for i = 1 to L-1:
U ← größte Potenz von 2, die i teilt
z_i += y_i * ρ_0 # Rote Zelle: direkte Abhängigkeit
z[i+1:i+U] += τ(y, [i-U+1, i], ρ, [i+1, i+U]) # Grauer Block: eifrige Berechnung
return z_i
unlock y_{i+1}
Schlüssel-Eigenschaften:
In der i-ten Iteration wird ein grauer Block mit Kantenlänge U berechnet (wobei U die größte Potenz von 2 ist, die i teilt)
Rote Zellen behandeln direkte Abhängigkeiten der aktuellen Position
Graue Blöcke berechnen im Voraus teilweise zukünftige Beiträge
Komplexitätsanalyse (Proposition 1):
Für Blöcke der Länge 2q gibt es 2P−1−q Aufrufe (wobei L=2P)
Gesamtzeit: ∑q=0P−12P−1−q⋅O(2qlog2q)=O(Llog2L)
Speicher: O(L) (Spitzenwert durch größten Block bestimmt)
Erweitert Algorithmus 1 auf mehrere Schichten und Dimensionen:
for i = 1 to L-1:
U ← größte Potenz von 2, die i teilt
for ℓ = 1 to M: # Schichten durchlaufen
b^ℓ_i += a^{ℓ-1}_i ⊙ ρ^ℓ_0 # Rote Zelle
a^ℓ_i = block^ℓ(b^ℓ_i)
b^ℓ[i+1:i+U] += τ(a^{ℓ-1}, [i-U+1, i], ρ^ℓ, [i+1, i+U]) # Grauer Block
a^0_{i+1} = sampler(a^M_i)
Graue Block-Berechnungen können über alle Schichten hinweg parallel ausgeführt werden:
for i = 1 to L-1:
for ℓ = 1 to M:
Rote Zellen verarbeiten (muss sequenziell sein)
parallel for ℓ = 1 to M:
Graue Blöcke verarbeiten (kann parallel sein)
Vorteile:
Kleine Blöcke (87,5% der Blöcke haben Größe ≤4) sind normalerweise speicherlatenzbegrenzt; Parallelisierung kann Speicherbandbreite sättigen
Große Blöcke verwenden FFT-Implementierung, sind rechenlastig; Parallelisierung verbessert Durchsatz
Durch Modifikation der Kachelung-Strategie (Algorithmus 5) wird die Unterstützung von Fällen, in denen ρ datenabhängig ist, mit 2× Berechnungsaufwand ermöglicht.
CUDA Graphs: Zeichnet alle Kernel-Aufrufe für einzelne Token-Generierung als Graph auf, spielt später ab um CPU-Overhead zu reduzieren (verbessert 10-20%)
FFT-Vorberechnung: Berechnet DFT von Faltungskernen für log2(L)−1 Blockgrößen vor
FlashFFT-Vorkonfiguration: Initialisiert Konfigurationen für verschiedene Blockgrößen vor, um Hardware-Performance zu maximieren
Rechts-Padding: Verwendet Rechts-Padding statt Links-Padding, reduziert Berechnungszeit um die Hälfte
Zirkuläre Faltung: Nutzt zirkuläre Faltungs-Eigenschaft um FFT-Länge zu halbieren
Architektur-Design: Entwurf neuer Architekturen, die Flash Inference-Anforderungen erfüllen und hohe Qualität bieten
Kausale datenabhängige Filter: Wie können Filter datenabhängig sein und gleichzeitig Kausalität bewahren (Arora et al., Karami & Ghodsi zeigen Potenzial)
Hybrid-Methoden: Kombination von rekursiver Perspektive (kleine Zustandsdimension) und Faltungs-Perspektive (große Zustandsdimension)
Weitere Architekturen: Erweiterung auf andere Modelle, die Framework-Eigenschaften erfüllen (z.B. bestimmte Attention-Varianten)
Verteilte Inferenz: Optimierungen für Multi-GPU/Multi-Node-Szenarien
van der Hoeven, J. (1997). Lazy multiplication of formal power series. ISSAC. Theoretische Grundlagen
Poli, M. et al. (2023). Hyena hierarchy: Towards larger convolutional language models. Hauptanwendungsobjekt
Massaroli, S. et al. (2024). Laughing hyena distillery: Extracting compact recurrences from convolutions. NeurIPS. Approximations-Methoden-Vergleich
Gu, A. & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. SSM-verwandte Arbeiten
Fu, D. Y. et al. (2023). FlashFFTConv: Efficient convolutions for long sequences with tensor cores. Implementierungs-Grundlagen
Agarwal, N. et al. (2024). FutureFill: Fast generation from convolutional sequence models. Parallele Arbeiten
Gesamtbewertung: Dies ist ein ausgezeichnetes Paper, das Theorie und Praxis eng verbindet. Theoretisch bietet es den ersten quasi-linearen exakten Inferenz-Algorithmus für LCSMs und abstrahiert ein universelles Framework; praktisch realisiert es durch systemische Optimierungen signifikante Beschleunigungen. Haupteinschränkungen sind, dass LCSMs selbst in praktischen Anwendungen nicht so verbreitet wie Transformer sind, und dass die experimentelle Validierung datenabhängiger Filter unzureichend ist. Diese Arbeit bietet neue Perspektiven auf Sequenzmodell-Inferenz-Optimierung, besonders wertvoll für zukünftige Architektur-Designs. Empfohlen für Forscher, die sich für Modell-Effizienz, Sequenzmodellierung und Systemoptimierung interessieren.