State-Space Models for Tabular Prior-Data Fitted Networks
Koch, Wever, Raisch et al.
Recent advancements in foundation models for tabular data, such as TabPFN, demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in SSM's inherent sensitivity to the order of input tokens - an undesirable property for tabular datasets where the row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces the order-dependence, achieving predictive performance competitive to the original TabPFN model.
academic
State-Space Models für tabellarische Prior-Data Fitted Networks
Jüngste Fortschritte bei Foundation Models für tabellarische Daten, wie TabPFN, haben gezeigt, dass vortrainierte Transformer-Architekturen bayesianische Inferenz mit hoher Vorhersageleistung approximieren können. Allerdings leiden Transformer unter quadratischer Komplexität in Bezug auf die Sequenzlänge, was die Erforschung effizienterer Sequenzmodelle motiviert. In dieser Arbeit untersuchen wir das Potenzial von Hydra, einem bidirektionalen linearen strukturierten State-Space-Modell (SSM), als Alternative zu Transformern in TabPFN. Eine Schlüsselherausforderung liegt in der inhärenten Empfindlichkeit von SSMs gegenüber der Reihenfolge von Eingabe-Token – eine unerwünschte Eigenschaft für tabellarische Datensätze, bei denen die Zeilenreihenfolge semantisch bedeutungslos ist. Wir untersuchen, inwieweit ein bidirektionaler Ansatz Effizienz bewahren und symmetrische Kontextaggregation ermöglichen kann. Unsere Experimente zeigen, dass dieser Ansatz die Reihenfolgeabhängigkeit reduziert und eine Vorhersageleistung erreicht, die mit dem ursprünglichen TabPFN-Modell konkurrenzfähig ist.
Zu lösende Probleme: Diese Forschung befasst sich mit dem Rechenkomplexitätsproblem der Transformer-Architektur in Foundation Models für tabellarische Daten, insbesondere mit ihrer O(n²)-Komplexität, die die Skalierbarkeit bei großen Datensätzen einschränkt.
Bedeutung des Problems: TabPFN als Foundation Model für tabellarische Daten zeigt hervorragende Leistung und kann bayesianische Inferenz-Approximation im Millisekundenbereich durchführen, aber seine Transformer-basierte Architektur sieht sich bei der Verarbeitung großer Datenmengen mit Speicher- und Rechenbeschränkungen konfrontiert.
Einschränkungen bestehender Methoden:
Der Self-Attention-Mechanismus von Transformern hat quadratische Komplexität
Der direkte Austausch von Transformer durch Mamba führt zu Empfindlichkeit gegenüber der Eingabesequenzreihenfolge
Die Zeilenreihenfolge in tabellarischen Daten ist semantisch bedeutungslos, was mit dem kausalen Design von SSMs kollidiert
Forschungsmotivation: Erforschung strukturierter State-Space-Modelle (SSM) als Alternative zu Transformern, um sowohl die Effizienzvorteile der linearen Komplexität zu bewahren als auch durch bidirektionale Verarbeitungsmechanismen die Abhängigkeit von der Eingabereihenfolge zu reduzieren.
Vorschlag einer auf Hydra basierenden TabPFN-Architektur: Integration des bidirektionalen strukturierten State-Space-Modells Hydra in TabPFN zur Realisierung linearer Zeitkomplexität bei der Verarbeitung tabellarischer Daten.
Einführung der Repeated Context Permutation (RCP) Technik: Weitere Reduzierung der SSM-Empfindlichkeit gegenüber der Sequenzreihenfolge durch mehrfaches zufälliges Permutieren von Eingaben und Durchschnittsbildung der Vorhersageergebnisse.
Realisierung signifikanter Skalierbarkeitssteigerungen: Im Vergleich zum ursprünglichen TabPFN kann die neue Methode zwei Größenordnungen größere Datensätze verarbeiten (Erweiterung von 2¹⁵ auf 2¹⁷ Zeilen).
Beibehaltung konkurrenzfähiger Vorhersageleistung: Bei der OpenML CC-18 Benchmark-Suite liegt die Genauigkeit des Hydra-basierten TabPFN nur 1,1% unter dem ursprünglichen Modell.
Eingabe: Permutationszahl r, Kontext D, Testprobe xtest
Ausgabe: Vorhergesagte Klassenwerte
Initialisiere leere Liste: outputs ← []
for i = 1 to r do
Permutiere Zeilen von D: Dp ← shuffle(D)
Verkette xtest mit Dp: Din ← Dp ∪ xtest
Vorhersage: outputs[i] ← PFN.predict(Din)
end for
Rückgabe Durchschnitt von outputs
Bidirektionalität löst Reihenfolgeempfindlichkeit: Im Vergleich zum unidirektionalen Mamba kann Hydra durch bidirektionale Verarbeitung Kontextinformationen symmetrisch aggregieren und die Abhängigkeit von der Eingabereihenfolge reduzieren.
Lineare Komplexität: Realisierung von O(n)-Komplexität durch quasi-separierbare Matrixmultiplikation, mit signifikantem Vorteil gegenüber der O(n²)-Komplexität von Transformern.
RCP-Strategie: Innovativer Ansatz zur weiteren Reduzierung der Reihenfolgeempfindlichkeit durch mehrfache zufällige Permutation und Ergebnisdurchschnittsbildung, ein maßgeschneidertes Design für die Charakteristiken tabellarischer Daten.
Dao et al. (2022) - FlashAttention-Optimierungstechnik
Zeng et al. (2024) - TabFlex lineare Aufmerksamkeitsmethode
Dieses Papier leistet einen wertvollen Beitrag zur Lösung des Skalierbarkeitsproblems tabellarischer Foundation Models. Durch geschickte Kombination bidirektionaler SSM und wiederholter Permutationsstrategien wird erfolgreich ein Ausgleich zwischen Effizienz und Leistung erreicht. Obwohl es in Bezug auf theoretische Innovation Mängel aufweist, sind sein praktischer Wert und seine Inspirationskraft für zukünftige Forschung bemerkenswert.