State-Space Models for Tabular Prior-Data Fitted Networks
Koch, Wever, Raisch et al.
Recent advancements in foundation models for tabular data, such as TabPFN, demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in SSM's inherent sensitivity to the order of input tokens - an undesirable property for tabular datasets where the row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces the order-dependence, achieving predictive performance competitive to the original TabPFN model.
academic
Modèles d'Espace d'État pour Réseaux de Données Tabulaires Pré-ajustés par Données Antérieures
Les avancées récentes dans les modèles fondamentaux pour données tabulaires, tels que TabPFN, ont démontré que les architectures Transformer pré-entraînées peuvent approximer l'inférence bayésienne avec des performances prédictives élevées. Cependant, les Transformers souffrent d'une complexité quadratique par rapport à la longueur de la séquence, motivant l'exploration de modèles de séquence plus efficaces. Dans ce travail, nous examinons le potentiel d'utiliser Hydra, un modèle d'espace d'état structuré bidirectionnel à temps linéaire (SSM), comme alternative aux Transformers dans TabPFN. Un défi clé réside dans la sensibilité inhérente des SSM à l'ordre des jetons d'entrée - une propriété indésirable pour les ensembles de données tabulaires où l'ordre des lignes est sémantiquement insignifiant. Nous examinons dans quelle mesure une approche bidirectionnelle peut préserver l'efficacité et permettre l'agrégation symétrique du contexte. Nos expériences montrent que cette approche réduit la dépendance à l'ordre, atteignant des performances prédictives compétitives par rapport au modèle TabPFN original.
Problème à résoudre: Cette recherche aborde le problème d'efficacité computationnelle de l'architecture Transformer dans les modèles fondamentaux pour données tabulaires, en particulier sa complexité O(n²) qui limite la scalabilité sur les grands ensembles de données.
Importance du problème: TabPFN en tant que modèle fondamental pour données tabulaires a démontré des performances exceptionnelles, capable d'approximer l'inférence bayésienne en millisecondes, mais son architecture basée sur Transformer fait face à des goulots d'étranglement mémoire et computationnels lors du traitement de données à grande échelle.
Limitations des approches existantes:
Le mécanisme d'auto-attention du Transformer possède une complexité quadratique
Remplacer directement le Transformer par Mamba introduit une sensibilité à l'ordre de la séquence d'entrée
L'ordre des lignes dans les données tabulaires est sémantiquement insignifiant, ce qui entre en conflit avec la conception causale des SSM
Motivation de la recherche: Explorer les modèles d'espace d'état structurés (SSM) comme alternative aux Transformers, en préservant les avantages d'efficacité de la complexité linéaire tout en réduisant la dépendance à l'ordre d'entrée par un mécanisme de traitement bidirectionnel.
Architecture TabPFN basée sur Hydra proposée: Intégration du modèle d'espace d'état structuré bidirectionnel Hydra dans TabPFN, réalisant un traitement des données tabulaires avec complexité temporelle linéaire.
Introduction de la technique de Permutation Répétée du Contexte (RCP): Réduction supplémentaire de la sensibilité des SSM à l'ordre des séquences par permutation aléatoire répétée des entrées et moyenne des résultats prédictifs.
Amélioration significative de la scalabilité: Comparée au TabPFN original, la nouvelle méthode peut traiter des ensembles de données deux ordres de grandeur plus importants (extension de 2¹⁵ à 2¹⁷ lignes).
Maintien de performances prédictives compétitives: Sur l'ensemble de référence OpenML CC-18, la précision du TabPFN basé sur Hydra n'est inférieure que de 1,1% au modèle original.
Entrée: nombre de permutations r, contexte D, échantillon de test xtest
Sortie: valeur de classe prédite
Initialiser liste vide: outputs ← []
pour i = 1 à r faire
Mélanger les lignes de D: Dp ← shuffle(D)
Concaténer xtest à Dp: Din ← Dp ∪ xtest
Prédire: outputs[i] ← PFN.predict(Din)
fin pour
Retourner la moyenne de outputs
Bidirectionnalité résolvant la sensibilité à l'ordre: Comparé au Mamba unidirectionnel, le traitement bidirectionnel d'Hydra peut agréger symétriquement les informations contextuelles, réduisant la dépendance à l'ordre d'entrée.
Complexité linéaire: Réalisation d'une complexité O(n) par multiplication de matrices quasi-séparables, offrant un avantage significatif par rapport à O(n²) du Transformer.
Stratégie RCP: Innovation consistant à réduire davantage la sensibilité à l'ordre par permutations aléatoires répétées et moyenne des résultats, conception personnalisée pour les caractéristiques des données tabulaires.
Analyse théorique insuffisante: Explication théorique insuffisante de pourquoi la bidirectionnalité résout le problème de sensibilité à l'ordre
Échelle expérimentale limitée: Toujours limitée à ensembles de données relativement petits, capacité de traitement à grande échelle insuffisamment démontrée
Comparaisons incomplètes: Manque de comparaisons directes avec autres méthodes de complexité linéaire (comme Linear Attention)
Analyse d'hyperparamètres insuffisante: Optimisation d'hyperparamètres insuffisante en raison des coûts d'entraînement élevés
Classification tabulaire à grande échelle: Particulièrement adapté aux tâches de classification tabulaire nécessitant traitement de nombreux échantillons
Scénarios d'inférence en temps réel: Complexité linéaire appropriée pour applications exigeant vitesse d'inférence stricte
Environnements à ressources limitées: Nécessite moins mémoire et ressources computationnelles comparé au Transformer
Apprentissage peu supervisé: Préserve avantages de TabPFN dans scénarios peu supervisés
Dao et al. (2022) - Technique d'optimisation FlashAttention
Zeng et al. (2024) - Méthode attention linéaire TabFlex
Cet article apporte une contribution précieuse à la résolution du problème de scalabilité des modèles fondamentaux tabulaires. En combinant intelligemment les SSM bidirectionnels et la stratégie de permutation répétée, il équilibre avec succès les exigences d'efficacité et de performance. Bien que présentant certaines insuffisances en innovation théorique, sa valeur pratique et sa signification inspirante pour recherches futures méritent reconnaissance.