Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Purandare, Idreos et al.
While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof of concept implementation for Hyena, which gets up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
academic
Flash Inference : Inférence en Temps Quasi-Linéaire pour les Modèles de Séquences à Convolution Longue et Au-Delà
Cet article propose le cadre Flash Inference pour résoudre le problème de complexité temporelle quadratique lors de l'inférence des modèles de séquences à convolution longue (LCSMs), réduisant la complexité temporelle de l'inférence exacte à quasi-linéaire O(Llog2L). La méthode s'inspire de l'interpolation polynomiale relaxée et repose sur une stratégie de partitionnement (tiling) pour réduire les mouvements mémoire et partager les calculs. Les expériences sur l'architecture Hyena démontrent une accélération d'inférence bout-à-bout de 7,8× et une accélération de 110× pour la partie mélange de positions.
Bien que les Transformers aient remporté un succès considérable dans les modèles de génération de séquences, leur coût de calcul croît quadratiquement avec la longueur de la séquence (O(L2)), ce qui constitue un goulot d'étranglement tant en phase d'entraînement qu'en phase d'inférence. Pour résoudre ce problème, les chercheurs ont proposé diverses architectures sous-quadratiques, notamment les modèles d'espace d'état (SSMs) et les modèles de séquences à convolution longue (LCSMs, tels que Hyena).
Efficacité d'entraînement résolue : Les LCSMs peuvent atteindre une complexité de O(LlogL) lors de l'entraînement via FFT
Efficacité d'inférence non résolue : Lors de l'inférence autorégressive, puisque la séquence d'entrée est générée progressivement, la FFT ne peut pas être utilisée directement, ce qui entraîne une dégradation de la complexité à O(L2)
Demande de contexte long : Avec les grands modèles de langage traitant des contextes de plus en plus longs, le problème d'efficacité d'inférence devient plus critique
Méthodes d'approximation (Massaroli et al. 2024) : Projettent le filtre de convolution dans un SSM LTI de faible dimension, mais il s'agit seulement d'une approximation nécessitant un prétraitement de distillation coûteux et ne supportant pas les filtres dépendants des données
Perspective récursive : Peut être efficace pour les SSMs de faible dimension, mais reste inefficace pour les SSMs de haute dimension (dimension proche de la longueur de la séquence)
Méthodes exploitant la structure : Nécessitent que le filtre possède une structure spécifique (comme un SSM LTI de faible rang), limitant la capacité d'expression du modèle
Cet article vise à fournir un cadre d'accélération d'inférence exact et universel, indépendant de la structure spécifique du filtre, tout en supportant les filtres dépendants des données.
Premier algorithme d'inférence exacte quasi-linéaire : Propose un algorithme d'inférence exacte avec complexité temporelle O(Llog2L) pour les LCSMs, réalisant une simulation exacte par rapport aux méthodes d'approximation antérieures
Identification d'un cadre universel : Identifie les propriétés architecturales clés permettant une inférence rapide (base de contribution, indépendance de requête) et propose le cadre Flash Inference applicable à une classe plus large d'architectures
Parallélisation inter-couches : Exploite la stratégie de partitionnement pour réaliser un calcul presque complètement parallèle inter-couches de la partie mélange de positions
Optimisation mémoire : Réduit significativement le mouvement de données via la méthode de partitionnement, de Ω(L2) à O(LlogL), économisant 2× le stockage d'activations pour les filtres indépendants des données
Validation empirique : Réalise une accélération bout-à-bout de 7,8× sur l'architecture Hyena et une accélération de 110× pour la partie convolution
Génération de séquences autorégressive : Étant donné une séquence d'amorce x1,…,xp, le modèle doit générer les tokens suivants un par un. À chaque position i, le modèle calcule les activations ai[1,M] à travers toutes les couches, puis échantillonne xi+1 à partir de aiM.
Goulot d'étranglement de calcul : Pour chaque couche ℓ et chaque dimension, il faut calculer :
zt=∑i=1tyi⋅ρt−i
où y est la séquence d'entrée et ρ est un filtre de convolution de longueur L. L'implémentation naïve nécessite un temps Ω(L2).
for i = 1 to L-1:
U ← la plus grande puissance de 2 divisant i
z_i += y_i * ρ_0 # cellule rouge : dépendance directe
z[i+1:i+U] += τ(y, [i-U+1, i], ρ, [i+1, i+U]) # bloc gris : calcul enthousiaste
return z_i
unlock y_{i+1}
Caractéristiques clés :
À la i-ième itération, calcule un bloc gris de côté U (où U est la plus grande puissance de 2 divisant i)
La cellule rouge traite la dépendance directe de la position actuelle
Le bloc gris calcule à l'avance une partie des contributions futures
Analyse de Complexité (Proposition 1) :
Pour les blocs de longueur 2q, il y a 2P−1−q appels (où L=2P)
Temps total : ∑q=0P−12P−1−q⋅O(2qlog2q)=O(Llog2L)
Mémoire : O(L) (pic déterminé par le bloc maximal)
Étend l'Algorithme 1 à plusieurs couches et dimensions :
for i = 1 to L-1:
U ← la plus grande puissance de 2 divisant i
for ℓ = 1 to M: # parcourir les couches
b^ℓ_i += a^{ℓ-1}_i ⊙ ρ^ℓ_0 # cellule rouge
a^ℓ_i = block^ℓ(b^ℓ_i)
b^ℓ[i+1:i+U] += τ(a^{ℓ-1}, [i-U+1, i], ρ^ℓ, [i+1, i+U]) # bloc gris
a^0_{i+1} = sampler(a^M_i)
Le calcul des blocs gris peut être exécuté en parallèle sur toutes les couches :
for i = 1 to L-1:
for ℓ = 1 to M:
traiter les cellules rouges (doit être séquentiel)
parallel for ℓ = 1 to M:
traiter les blocs gris (peut être parallèle)
Avantages :
Les petits blocs (87,5% des blocs de taille ≤4) sont généralement limités par la latence mémoire, la parallélisation peut saturer la bande passante mémoire
Les grands blocs utilisent FFT, intensifs en calcul, la parallélisation améliore le débit
La convolution FFT standard nécessite une FFT de longueur 4U (longueur de sortie 3U-1)
Cet article n'a besoin que d'une convolution circulaire de longueur 2U (la plage de sortie intéressante [U,2U−1] n'est pas affectée par la circularité)
P.2 Indépendance de Requête (Query-independent) :
cont(y,i,j) ne dépend pas de y[i+1,L] (les LCSMs satisfont cette propriété, les Transformers ne la satisfont pas)
Supposer qu'il existe un algorithme A capable de calculer les contributions de bloc en temps T(L1,L2) :
A(y,[l,r],[l′,r′])=agg(cont(y,l,p),…,cont(y,r,p))
Théorème 2 : Sous P.1 et P.2, chaque couche exécute :
L−1 appels à A (avec 2P−1−q appels de longueur 2q)
Temps total : ∑q=0P−12P−1−qT(2q,2q)
Parallélisation inter-couches : les blocs gris n'ont pas de dépendances de données, peuvent être parallélisés
CUDA Graphs : enregistre tous les noyaux d'une génération de token unique en tant que graphe, rejoué ultérieurement pour réduire les frais généraux CPU (amélioration de 10-20%)
Prétraitement FFT : prétraite la DFT du noyau de convolution pour log2(L)−1 tailles de bloc
Préconfiguration FlashFFT : préinitialise les configurations pour différentes tailles de bloc afin de maximiser les performances matérielles
Remplissage à droite : utilise le remplissage à droite plutôt que le remplissage à gauche, réduisant de moitié le temps de calcul
Convolution circulaire : exploite la propriété de convolution circulaire pour réduire de moitié la longueur FFT
Cohérence théorie-pratique : la complexité O(Llog2L) se manifeste par une accélération significative dans les expériences
Importance de la bande passante mémoire : Flash Conv1D bien que quadratique, obtient 5× d'accélération via optimisation d'accès mémoire
Nécessité de la sélection dynamique : aucune implémentation unique de τ n'est optimale pour toutes les tailles de bloc, la stratégie Hybrid est cruciale
Frais généraux CPU non négligeables : CUDA Graphs améliore l'accélération bout-à-bout de 1,6× à 8×
Bénéfices de la parallélisation : les petits blocs dominent (87,5%), la parallélisation inter-couches est très efficace
Filtres dépendants des données : bien que théoriquement supportés, nécessitent 2× le calcul, validation expérimentale insuffisante
Besoins mémoire : nécessite toujours le stockage de toutes les activations O(MLD) (vs perspective récursive O(MD′))
Portée d'application :
Inapplicable aux Transformers (ne satisfait pas l'indépendance de requête)
Pour les SSMs très faible dimension (D′≪log2L), la perspective récursive peut être plus optimale
Phase de prompt : avec des prompts longs, le prétraitement (prefill) domine toujours le temps, l'optimisation de la génération autorégressive a un bénéfice relatif limité
Dépendance matérielle : l'effet d'accélération dépend des caractéristiques de bande passante mémoire du GPU
Conception architecturale : concevoir de nouvelles architectures satisfaisant les exigences de Flash Inference avec haute qualité
Filtres dépendants des données causaux : comment rendre les filtres dépendants des données tout en maintenant la causalité (Arora et al., Karami & Ghodsi ont montré du potentiel)
Approches hybrides : combiner la perspective récursive (faible dimension d'état) et la perspective convolutive (haute dimension d'état)
Plus d'architectures : étendre à d'autres modèles satisfaisant les propriétés du cadre (comme certaines variantes d'attention)
Inférence distribuée : optimisations pour scénarios multi-GPU/multi-nœuds
Analyse de complexité complète : du Lemme 1 au Théorème 2, chaîne de preuve claire
Abstraction du cadre universel : les propriétés P.1 et P.2 sont abstraites de manière appropriée, incluant les LCSMs tout en excluant les cas inapplicables (comme les Transformers)
Choix d'outils mathématiques : application ingénieuse de la théorie de l'interpolation polynomiale relaxée
Stratégie de partitionnement : le partitionnement rectangulaire équilibré (vs bandes étroites) est une intuition clé
Parallélisation inter-couches : identifie que les blocs gris n'ont pas de dépendances, dépassant les limites d'exécution séquentielle traditionnelle des couches
Sélection dynamique d'implémentation : la stratégie Hybrid reflète une compréhension profonde des caractéristiques matérielles
Vs méthodes d'approximation : pas de comparaison expérimentale du compromis qualité-vitesse avec Massaroli et al.
Vs perspective récursive : analyse quantitative insuffisante sur quand la perspective récursive est plus optimale (seulement discussion qualitative sur D′∈O(log2L))
Vs exploitation de structure : pas de comparaison avec des méthodes de structure spécifique comme les convolutions dilatées
van der Hoeven, J. (1997). Lazy multiplication of formal power series. ISSAC. Fondations théoriques
Poli, M. et al. (2023). Hyena hierarchy: Towards larger convolutional language models. Principal objet d'application
Massaroli, S. et al. (2024). Laughing hyena distillery: Extracting compact recurrences from convolutions. NeurIPS. Comparaison des méthodes d'approximation
Gu, A. & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. Travaux connexes SSM
Fu, D. Y. et al. (2023). FlashFFTConv: Efficient convolutions for long sequences with tensor cores. Fondations d'implémentation
Agarwal, N. et al. (2024). FutureFill: Fast generation from convolutional sequence models. Travaux parallèles
Évaluation Globale : Ceci est un excellent article combinant étroitement théorie et pratique. Sur le plan théorique, il fournit le premier algorithme d'inférence exacte quasi-linéaire pour les LCSMs et abstrait un cadre universel ; sur le plan pratique, il réalise une accélération significative via des optimisations au niveau système. Les principales limitations résident dans le fait que les LCSMs eux-mêmes ne sont pas aussi largement adoptés que les Transformers dans les applications réelles, et la validation expérimentale des filtres dépendants des données est insuffisante. Ce travail offre une nouvelle perspective sur l'optimisation d'inférence des modèles de séquences, particulièrement instructif pour la conception future d'architectures. Recommandé aux chercheurs intéressés par l'efficacité des modèles, la modélisation de séquences et l'optimisation système.