State-Space Models for Tabular Prior-Data Fitted Networks
Koch, Wever, Raisch et al.
Recent advancements in foundation models for tabular data, such as TabPFN, demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in SSM's inherent sensitivity to the order of input tokens - an undesirable property for tabular datasets where the row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces the order-dependence, achieving predictive performance competitive to the original TabPFN model.
academic
Modelli State-Space per Reti Tabellari Prior-Data Fitted
I recenti progressi nei modelli fondamentali per dati tabulari, come TabPFN, hanno dimostrato che le architetture Transformer pre-addestrate possono approssimare l'inferenza bayesiana con elevate prestazioni predittive. Tuttavia, i Transformer soffrono di complessità quadratica rispetto alla lunghezza della sequenza, motivando l'esplorazione di modelli di sequenza più efficienti. In questo lavoro, investighiamo il potenziale dell'utilizzo di Hydra, un modello bidirezionale di spazio degli stati strutturato a tempo lineare (SSM), come alternativa ai Transformer in TabPFN. Una sfida chiave risiede nella sensibilità intrinseca dell'SSM all'ordine dei token di input - una proprietà indesiderabile per i dataset tabulari dove l'ordine delle righe è semanticamente insignificante. Investighiamo in che misura un approccio bidirezionale possa preservare l'efficienza e abilitare l'aggregazione simmetrica del contesto. I nostri esperimenti mostrano che questo approccio riduce la dipendenza dall'ordine, raggiungendo prestazioni predittive competitive con il modello TabPFN originale.
Problema da risolvere: Questa ricerca affronta il problema dell'efficienza computazionale dell'architettura Transformer nei modelli fondamentali per dati tabulari, in particolare la sua complessità O(n²) che limita la scalabilità su dataset di grandi dimensioni.
Importanza del problema: TabPFN come modello fondamentale per dati tabulari ha dimostrato prestazioni eccellenti, completando l'approssimazione dell'inferenza bayesiana in millisecondi, ma la sua architettura basata su Transformer affronta colli di bottiglia di memoria e calcolo nel trattamento di dati su larga scala.
Limitazioni dei metodi esistenti:
Il meccanismo di auto-attenzione del Transformer ha complessità quadratica
La sostituzione diretta di Mamba al Transformer introduce sensibilità all'ordine della sequenza di input
L'ordine delle righe nei dati tabulari è semanticamente insignificante, entrando in conflitto con il design causale dell'SSM
Motivazione della ricerca: Esplorare i modelli di spazio degli stati strutturati (SSM) come alternativa ai Transformer, mantenendo i vantaggi di efficienza della complessità lineare, riducendo al contempo la dipendenza dall'ordine di input attraverso un meccanismo di elaborazione bidirezionale.
Proposta dell'architettura TabPFN basata su Hydra: Integrazione del modello di spazio degli stati strutturato bidirezionale Hydra in TabPFN, realizzando l'elaborazione dei dati tabulari con complessità temporale lineare.
Introduzione della tecnica di Permutazione Ripetuta del Contesto (RCP): Riduzione ulteriore della sensibilità dell'SSM all'ordine della sequenza attraverso permutazioni casuali multiple dell'input e media dei risultati predittivi.
Realizzazione di un significativo miglioramento della scalabilità: Rispetto a TabPFN originale, il nuovo metodo può elaborare dataset due ordini di grandezza più grandi (da 2¹⁵ a 2¹⁷ righe).
Mantenimento di prestazioni predittive competitive: Nel benchmark OpenML CC-18, l'accuratezza di Hydra-based TabPFN è inferiore solo dell'1,1% rispetto al modello originale.
Input: numero di permutazioni r, contesto D, campione di test xtest
Output: valore di classe predetto
Inizializza lista vuota: outputs ← []
for i = 1 to r do
Mescola righe di D: Dp ← shuffle(D)
Concatena xtest a Dp: Din ← Dp ∪ xtest
Predizione: outputs[i] ← PFN.predict(Din)
end for
Restituisci media di outputs
Bidirezionalità per risolvere la sensibilità all'ordine: Rispetto a Mamba unidirezionale, l'elaborazione bidirezionale di Hydra può aggregare simmetricamente le informazioni di contesto, riducendo la dipendenza dall'ordine di input.
Complessità lineare: Realizzazione della complessità O(n) attraverso moltiplicatori di matrici quasi-separabili, con vantaggi significativi rispetto a O(n²) del Transformer.
Strategia RCP: Innovativa riduzione della sensibilità all'ordine attraverso permutazioni casuali multiple e media dei risultati, un design personalizzato per le caratteristiche dei dati tabulari.
Grado di innovazione limitato: Principalmente combinazione e applicazione di tecniche esistenti, mancanza di innovazione fondamentale
Analisi teorica insufficiente: Mancanza di spiegazione teorica approfondita del perché la bidirezionalità risolve il problema della sensibilità all'ordine
Scala sperimentale limitata: Ancora limitata a dataset relativamente piccoli, incapace di dimostrare pienamente la capacità di elaborazione su larga scala
Confronto incompleto: Mancanza di confronto diretto con altri metodi di complessità lineare (come Linear Attention)
Analisi degli iperparametri insufficiente: A causa dell'alto costo di addestramento, mancanza di ottimizzazione sufficiente degli iperparametri
Classificazione tabulare su larga scala: Particolarmente adatto per compiti di classificazione tabulare che richiedono l'elaborazione di un gran numero di campioni
Scenari di inferenza in tempo reale: La complessità lineare la rende adatta per applicazioni con requisiti rigorosi sulla velocità di inferenza
Ambienti con risorse limitate: Richiede meno memoria e risorse computazionali rispetto al Transformer
Apprendimento con pochi campioni: Mantiene i vantaggi di TabPFN negli scenari di apprendimento con pochi campioni
Hollmann et al. (2023) - Articolo originale di TabPFN
Gu & Dao (2023) - Architettura Mamba
Hwang et al. (2024) - SSM bidirezionale Hydra
Dao et al. (2022) - Tecnica di ottimizzazione FlashAttention
Zeng et al. (2024) - Metodo di attenzione lineare TabFlex
Questo articolo fornisce un contributo prezioso nella risoluzione del problema della scalabilità dei modelli fondamentali tabulari, combinando abilmente SSM bidirezionale e strategia di permutazione ripetuta, raggiungendo con successo l'equilibrio tra i requisiti di efficienza e prestazioni. Sebbene presenti insufficienze nell'innovazione teorica, il suo valore pratico e il significato ispiratore per la ricerca futura meritano riconoscimento.