The quadratic complexity of self-attention limits its applicability and scalability on large unstructured meshes. We introduce Fast Low-rank Attention Routing Engine (FLARE), a linear complexity self-attention mechanism that routes attention through fixed-length latent sequences. Each attention head performs global communication among $N$ tokens by projecting the input sequence onto a fixed length latent sequence of $M \ll N$ tokens using learnable query tokens. By routing attention through a bottleneck sequence, FLARE learns a low-rank form of attention that can be applied at $O(NM)$ cost. FLARE not only scales to unprecedented problem sizes, but also delivers superior accuracy compared to state-of-the-art neural PDE surrogates across diverse benchmarks. We also release a new additive manufacturing dataset to spur further research. Our code is available at https://github.com/vpuri3/FLARE.py.
La complexité quadratique des mécanismes d'auto-attention traditionnels limite leur applicabilité et leur scalabilité sur les maillages non structurés à grande échelle. Cet article propose FLARE (Fast Low-rank Attention Routing Engine), un mécanisme d'auto-attention à complexité linéaire qui achemine l'attention via une séquence latente de longueur fixe. Chaque tête d'attention projette la séquence d'entrée sur une séquence latente de longueur fixe M≪N en utilisant des jetons de requête apprenables, réalisant ainsi une communication globale entre N jetons. En acheminant l'attention via une séquence goulot d'étranglement, FLARE apprend une forme d'attention de faible rang qui peut être appliquée au coût O(NM). FLARE non seulement s'étend à des échelles de problèmes sans précédent, mais fournit également une meilleure précision par rapport aux modèles d'agents PDE neuraux de pointe sur plusieurs benchmarks.
Problème fondamental: Le mécanisme d'auto-attention du Transformer traditionnel possède une complexité temporelle et mémoire O(N²), ce qui limite sévèrement son application sur les maillages non structurés à grande échelle (tels que les nuages de points et les maillages dans les simulations physiques).
Importance de l'application: Dans la modélisation d'agents PDE (équations aux dérivées partielles), chaque point dans un nuage de points 3D est considéré comme un jeton contenant des caractéristiques géométriques et physiques (telles que les coordonnées, les vecteurs normaux, les propriétés matérielles). La simulation de systèmes physiques haute fidélité est trop coûteuse; les modèles d'agents d'apprentissage automatique offrent une alternative d'approximation rapide.
Limitations des méthodes existantes:
PerceiverIO: Effectue uniquement un encodage et un décodage uniques; le goulot d'étranglement potentiel peut limiter la précision
Transolver: Partage les poids de projection entre les têtes, ne peut pas exploiter les noyaux GPU existants pour l'attention au produit scalaire échelonné
LNO: Applique uniquement une projection unique, manque de capacité de modèle profond
Motivation de la recherche: Développer un mécanisme d'attention capable de maintenir la communication globale tout en possédant une complexité linéaire, permettant aux Transformers de traiter des géométries avec des millions de points.
Mélange de jetons à complexité linéaire: Propose le mécanisme d'auto-attention FLARE, réalisant une complexité linéaire en remplaçant l'auto-attention complète par une projection et reconstruction de faible rang.
Précision supérieure: Sur plusieurs benchmarks PDE, FLARE atteint une précision de prédiction supérieure aux modèles d'agents neuraux leaders avec moins de paramètres et une complexité computationnelle inférieure.
Scalabilité sans précédent: FLARE est entièrement construit sur des primitives d'attention fusionnées standard, assurant une utilisation GPU élevée et supportant l'entraînement bout en bout sur des maillages non structurés avec des millions de points.
Nouveau benchmark de données: Publie un ensemble de données haute résolution à grande échelle pour la fabrication additive métallique destiné à la recherche sur la prédiction de déplacement résiduel.
Étant donné une séquence d'entrée X ∈ R^(N×C), où N est le nombre de jetons et C est la dimension des caractéristiques, FLARE vise à apprendre un mécanisme d'attention à complexité linéaire réalisant une communication efficace entre jetons globale.
Projection indépendante par tête: Contrairement à Transolver qui partage les poids de projection, FLARE assigne à chaque tête une tranche différente de jetons latents, permettant à chaque tête d'apprendre des relations d'attention indépendantes.
MLP résiduel profond: Utilise des réseaux résiduels profonds pour la projection clé/valeur, apprenant des interactions de caractéristiques d'ordre supérieur par rapport aux couches linéaires simples.
Conception symétrique encodage-décodage: La symétrie des opérations d'encodage et de décodage favorise un flux d'information stable.
Compatibilité avec les noyaux fusionnés: Entièrement basé sur les opérations SDPA standard, peut exploiter les algorithmes d'optimisation tels que Flash Attention.
FLARE entraîne avec succès l'ensemble de données DrivAerML avec millions de points sur un seul GPU H100, étant le premier modèle d'agent neuraux basé sur l'attention à traiter des millions de points sans déchargement mémoire ou calcul distribué.
Contribution académique: Fournit un nouveau paradigme de conception pour les mécanismes d'attention efficaces, particulièrement dans le calcul scientifique
Valeur pratique: Permet aux Transformers de traiter des problèmes géométriques à grande échelle, promouvant le développement de l'IA pour la science
Reproductibilité: Code open source, configuration expérimentale détaillée, facilitant les recherches ultérieures
L'article cite des travaux importants dans les domaines connexes du Transformer, des opérateurs neuraux et des mécanismes d'attention efficaces, fournissant une base théorique solide et des benchmarks de comparaison.
Évaluation globale: Cet article est une recherche de haute qualité proposant une solution innovante pour résoudre le problème de scalabilité du Transformer. La méthode FLARE possède non seulement une explication élégante de décomposition de faible rang en théorie, mais démontre également d'excellentes performances en pratique. La conception expérimentale est suffisante, l'analyse théorique est approfondie, et elle a une importance significative pour promouvoir l'apprentissage géométrique profond à grande échelle et le calcul scientifique.