2025-11-17T15:49:13.397134

FLARE: Fast Low-rank Attention Routing Engine

Puri, Joglekar, Ferguson et al.
The quadratic complexity of self-attention limits its applicability and scalability on large unstructured meshes. We introduce Fast Low-rank Attention Routing Engine (FLARE), a linear complexity self-attention mechanism that routes attention through fixed-length latent sequences. Each attention head performs global communication among $N$ tokens by projecting the input sequence onto a fixed length latent sequence of $M \ll N$ tokens using learnable query tokens. By routing attention through a bottleneck sequence, FLARE learns a low-rank form of attention that can be applied at $O(NM)$ cost. FLARE not only scales to unprecedented problem sizes, but also delivers superior accuracy compared to state-of-the-art neural PDE surrogates across diverse benchmarks. We also release a new additive manufacturing dataset to spur further research. Our code is available at https://github.com/vpuri3/FLARE.py.
academic

FLARE: Fast Low-rank Attention Routing Engine

Informations de base

  • ID de l'article: 2508.12594
  • Titre: FLARE: Fast Low-rank Attention Routing Engine
  • Auteurs: Vedant Puri, Aditya Joglekar, Kevin Ferguson, Yu-hsuan Chen, Yongjie Jessica Zhang, Levent Burak Kara (Carnegie Mellon University)
  • Classification: cs.LG (Apprentissage automatique)
  • Date de publication: 15 octobre 2025 (arXiv v2)
  • Lien de l'article: https://arxiv.org/abs/2508.12594

Résumé

La complexité quadratique des mécanismes d'auto-attention traditionnels limite leur applicabilité et leur scalabilité sur les maillages non structurés à grande échelle. Cet article propose FLARE (Fast Low-rank Attention Routing Engine), un mécanisme d'auto-attention à complexité linéaire qui achemine l'attention via une séquence latente de longueur fixe. Chaque tête d'attention projette la séquence d'entrée sur une séquence latente de longueur fixe M≪N en utilisant des jetons de requête apprenables, réalisant ainsi une communication globale entre N jetons. En acheminant l'attention via une séquence goulot d'étranglement, FLARE apprend une forme d'attention de faible rang qui peut être appliquée au coût O(NM). FLARE non seulement s'étend à des échelles de problèmes sans précédent, mais fournit également une meilleure précision par rapport aux modèles d'agents PDE neuraux de pointe sur plusieurs benchmarks.

Contexte et motivation de la recherche

Contexte du problème

  1. Problème fondamental: Le mécanisme d'auto-attention du Transformer traditionnel possède une complexité temporelle et mémoire O(N²), ce qui limite sévèrement son application sur les maillages non structurés à grande échelle (tels que les nuages de points et les maillages dans les simulations physiques).
  2. Importance de l'application: Dans la modélisation d'agents PDE (équations aux dérivées partielles), chaque point dans un nuage de points 3D est considéré comme un jeton contenant des caractéristiques géométriques et physiques (telles que les coordonnées, les vecteurs normaux, les propriétés matérielles). La simulation de systèmes physiques haute fidélité est trop coûteuse; les modèles d'agents d'apprentissage automatique offrent une alternative d'approximation rapide.
  3. Limitations des méthodes existantes:
    • PerceiverIO: Effectue uniquement un encodage et un décodage uniques; le goulot d'étranglement potentiel peut limiter la précision
    • Transolver: Partage les poids de projection entre les têtes, ne peut pas exploiter les noyaux GPU existants pour l'attention au produit scalaire échelonné
    • LNO: Applique uniquement une projection unique, manque de capacité de modèle profond
  4. Motivation de la recherche: Développer un mécanisme d'attention capable de maintenir la communication globale tout en possédant une complexité linéaire, permettant aux Transformers de traiter des géométries avec des millions de points.

Contributions principales

  1. Mélange de jetons à complexité linéaire: Propose le mécanisme d'auto-attention FLARE, réalisant une complexité linéaire en remplaçant l'auto-attention complète par une projection et reconstruction de faible rang.
  2. Précision supérieure: Sur plusieurs benchmarks PDE, FLARE atteint une précision de prédiction supérieure aux modèles d'agents neuraux leaders avec moins de paramètres et une complexité computationnelle inférieure.
  3. Scalabilité sans précédent: FLARE est entièrement construit sur des primitives d'attention fusionnées standard, assurant une utilisation GPU élevée et supportant l'entraînement bout en bout sur des maillages non structurés avec des millions de points.
  4. Nouveau benchmark de données: Publie un ensemble de données haute résolution à grande échelle pour la fabrication additive métallique destiné à la recherche sur la prédiction de déplacement résiduel.

Détails de la méthode

Définition de la tâche

Étant donné une séquence d'entrée X ∈ R^(N×C), où N est le nombre de jetons et C est la dimension des caractéristiques, FLARE vise à apprendre un mécanisme d'attention à complexité linéaire réalisant une communication efficace entre jetons globale.

Architecture du modèle

Mécanisme principal de FLARE

FLARE introduit M≪N jetons latents apprenables comme goulot d'étranglement pour l'échange d'informations, comprenant deux étapes:

  1. Étape d'encodage: La séquence d'entrée est projetée sur les jetons latents via attention croisée
    Z_h = SDPA(Q_h, K_h, V_h, s=1)
    

    où Q_h ∈ R^(M×D) est la matrice de requête apprenables, K_h, V_h ∈ R^(N×D)
  2. Étape de décodage: Les jetons latents sont projetés sur la séquence d'entrée
    Y_h = SDPA(K_h, Q_h, Z_h, s=1)
    

Matrice de communication de faible rang

L'ensemble du processus est équivalent à:

Y_h = (W_decode,h · W_encode,h) · V_h

où:

  • W_encode,h = softmax(Q_h · K_h^T) ∈ R^(M×N)
  • W_decode,h = softmax(K_h · Q_h^T) ∈ R^(N×M)
  • W_h = W_decode,h · W_encode,h ∈ R^(N×N) est la matrice de communication globale de rang au maximum M

Structure du bloc FLARE

X = X + FLARE(LayerNorm(X))
X = X + ResMLP(LayerNorm(X))

Points d'innovation technique

  1. Projection indépendante par tête: Contrairement à Transolver qui partage les poids de projection, FLARE assigne à chaque tête une tranche différente de jetons latents, permettant à chaque tête d'apprendre des relations d'attention indépendantes.
  2. MLP résiduel profond: Utilise des réseaux résiduels profonds pour la projection clé/valeur, apprenant des interactions de caractéristiques d'ordre supérieur par rapport aux couches linéaires simples.
  3. Conception symétrique encodage-décodage: La symétrie des opérations d'encodage et de décodage favorise un flux d'information stable.
  4. Compatibilité avec les noyaux fusionnés: Entièrement basé sur les opérations SDPA standard, peut exploiter les algorithmes d'optimisation tels que Flash Attention.

Configuration expérimentale

Ensembles de données

L'article évalue 6 ensembles de données de benchmark et 1 nouvel ensemble de données proposé:

Ensemble de donnéesDimensionType de maillageNombre de pointsCaractéristiques entrée/sortieÉchantillons entraînement/test
Elasticity2DNon structuré9722/11000/200
Darcy2DStructuré7,2252/11000/200
Airfoil2DStructuré11,2712/11000/200
Pipe2DStructuré16,6412/11000/200
DrivAerML-40k3DNon structuré40,0003/1387/97
LPBF3DNon structuré1,000-50,0003/11100/290

Métriques d'évaluation

Utilise principalement l'erreur L2 relative:

Relative L2 = ||û - u||₂ / ||u||₂

Méthodes de comparaison

  • Modèles d'attention généraux: Vanilla Transformer, PerceiverIO
  • Agents PDE basés sur l'attention: Transolver, LNO
  • Opérateurs neuraux: GNOT

Détails d'implémentation

  • Optimiseur: AdamW (β₁=0.9, β₂=0.999)
  • Planification du taux d'apprentissage: OneCycleLR, taux d'apprentissage maximal 10⁻³
  • Nombre d'épochs: 500 pour les problèmes 2D, 250 pour LPBF
  • Taille de batch: 2 pour les problèmes 2D, 1 pour les problèmes 3D

Résultats expérimentaux

Résultats principaux

FLARE atteint les résultats optimaux ou quasi-optimaux sur tous les benchmarks:

ModèleElasticityDarcyAirfoilPipeDrivAerML-40kLPBF
Vanilla Transformer5.374.386.28
PerceiverIO23.421.51627.1476056.3
GNOT13.316.91035.8911524.3
LNO9.257.6417.88.1014624.7
Transolver s/conv6.4018.68.244.8770.520.4
Transolver avec conv\5.945.503.90\\
FLARE (nôtre)3.385.104.282.8560.818.5

Remarque: Les valeurs sont l'erreur L2 relative (×10⁻³)

Expériences sur géométries avec millions de points

FLARE entraîne avec succès l'ensemble de données DrivAerML avec millions de points sur un seul GPU H100, étant le premier modèle d'agent neuraux basé sur l'attention à traiter des millions de points sans déchargement mémoire ou calcul distribué.

Études d'ablation

  1. Impact du nombre de blocs (B) et du nombre de jetons latents (M):
    • L'augmentation du nombre de blocs réduit continuellement l'erreur relative
    • L'augmentation de M améliore généralement les performances, mais la tendance n'est pas strictement monotone
    • Différents problèmes nécessitent différents rangs
  2. Complexité temporelle et mémoire:
    • FLARE est plus de 200 fois plus rapide que l'attention vanilla
    • L'utilisation mémoire est légèrement supérieure à l'attention vanilla mais bien inférieure à Physics Attention

Analyse spectrale

Analyse les matrices de communication apprises via un algorithme de décomposition en valeurs propres de complexité O(M³+M²N):

  • Les valeurs propres décroissent rapidement dans les blocs précoces, indiquant une compression efficace
  • Les blocs profonds utilisent plus de capacité latente
  • Différentes têtes possèdent différents profils spectraux, validant la conception de projection indépendante par tête

Travaux connexes

Agents PDE neuraux

  • Opérateurs neuraux: FNO, DeepONet et autres apprennent les mappages entre espaces de fonctions de dimension infinie
  • Réseaux de graphes: Exploitent les interactions de voisinage local sur les maillages
  • Architecture Transformer: Permettent l'agrégation de contexte global mais sont limités par la complexité quadratique

Mécanismes d'attention efficaces

  • Linformer: Projette les séquences clé-valeur via des mappages linéaires appris
  • Reformer: Utilise le hachage sensible à la localité
  • Nyströmformer: Utilise la méthode de Nyström pour approximer l'auto-attention
  • LoRA: L'adaptation de faible rang est principalement utilisée pour l'ajustement efficace

Conclusion et discussion

Conclusions principales

  1. FLARE contourne avec succès le goulot d'étranglement de complexité quadratique de l'auto-attention via un mécanisme d'attention de faible rang
  2. Atteint la précision SOTA sur plusieurs benchmarks PDE avec moins de paramètres et une complexité computationnelle inférieure
  3. Réalise pour la première fois l'entraînement de modèles d'agents neuraux basés sur l'attention sur des géométries avec millions de points

Limitations

  1. Dépendance au MLP résiduel profond: Peut introduire un goulot d'étranglement séquentiel et augmenter la latence
  2. Limitation des jetons latents fixes: Le choix de M nécessite un ajustement spécifique au problème
  3. Applicabilité à certains problèmes de rang élevé: Comme dans le problème Darcy où Vanilla Transformer conserve un avantage

Directions futures

  1. Augmenter progressivement le nombre de jetons latents pendant l'entraînement
  2. Concevoir des jetons latents conditionnés temporellement pour la modélisation par diffusion
  3. Développer des variantes décodeur uniquement pour la modélisation autorégressive
  4. Résoudre le goulot d'étranglement séquentiel du MLP résiduel profond

Évaluation approfondie

Avantages

  1. Innovation technique forte:
    • Transforme intelligemment le problème d'acheminement d'attention en décomposition matricielle de faible rang
    • La conception de projection indépendante par tête permet des motifs d'acheminement spécialisés
    • Entièrement compatible avec les noyaux GPU existants
  2. Expérimentation suffisante:
    • Couvre 6 benchmarks PDE différents
    • Études d'ablation détaillées et analyse spectrale
    • Premières expériences à l'échelle des millions de points
  3. Analyse théorique approfondie:
    • Fournit un algorithme de décomposition en valeurs propres O(M³+M²N)
    • Explique mathématiquement l'efficacité de la communication de faible rang
    • Valide les hypothèses de conception via analyse spectrale
  4. Valeur pratique élevée:
    • Publie un nouvel ensemble de données de fabrication additive
    • Code open source, facilitant la reproduction
    • Peut s'intégrer directement dans les architectures Transformer existantes

Insuffisances

  1. Limitations d'applicabilité de la méthode:
    • Efficacité limitée sur les problèmes de rang élevé (comme Darcy)
    • Le choix de M nécessite un ajustement spécifique au problème
    • Le MLP profond peut devenir un nouveau goulot d'étranglement computationnel
  2. Limitations de la configuration expérimentale:
    • Manque de comparaisons avec plus de méthodes récentes
    • L'échelle de certains benchmarks est relativement petite
    • L'universalité pour différents types de problèmes PDE nécessite plus de validation
  3. Analyse théorique insuffisante:
    • Manque d'analyse de convergence
    • Guidance théorique limitée pour le choix optimal de M
    • La validité de l'hypothèse de faible rang pour tous les problèmes PDE nécessite justification supplémentaire

Impact

  1. Contribution académique: Fournit un nouveau paradigme de conception pour les mécanismes d'attention efficaces, particulièrement dans le calcul scientifique
  2. Valeur pratique: Permet aux Transformers de traiter des problèmes géométriques à grande échelle, promouvant le développement de l'IA pour la science
  3. Reproductibilité: Code open source, configuration expérimentale détaillée, facilitant les recherches ultérieures

Scénarios d'application

  • Résolution PDE sur maillages non structurés à grande échelle
  • Traitement de nuages de points et apprentissage géométrique profond
  • Tâches de modélisation de séquences nécessitant communication globale avec ressources computationnelles limitées
  • Applications de modélisation d'agents en calcul scientifique

Références

L'article cite des travaux importants dans les domaines connexes du Transformer, des opérateurs neuraux et des mécanismes d'attention efficaces, fournissant une base théorique solide et des benchmarks de comparaison.


Évaluation globale: Cet article est une recherche de haute qualité proposant une solution innovante pour résoudre le problème de scalabilité du Transformer. La méthode FLARE possède non seulement une explication élégante de décomposition de faible rang en théorie, mais démontre également d'excellentes performances en pratique. La conception expérimentale est suffisante, l'analyse théorique est approfondie, et elle a une importance significative pour promouvoir l'apprentissage géométrique profond à grande échelle et le calcul scientifique.