State-Space Models for Tabular Prior-Data Fitted Networks
Koch, Wever, Raisch et al.
Recent advances in foundation models for tabular data, such as TabPFN, have demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in SSMs' inherent sensitivity to the order of input tokens, an undesirable property for tabular datasets, where row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces order dependence and achieves predictive performance competitive with the original TabPFN model.
Problem to be Addressed: This research targets the computational cost of Transformer architectures in foundation models for tabular data, in particular the O(n²) complexity of self-attention in the number of rows n, which limits scalability to large datasets.
Problem Significance: As a foundation model for tabular data, TabPFN performs very well, approximating Bayesian inference within milliseconds. However, its Transformer-based architecture runs into memory and compute bottlenecks on large datasets.
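To make the quadratic bottleneck concrete, the back-of-envelope calculation below estimates the memory needed for a naively materialized n × n attention score matrix at the two dataset sizes mentioned in this summary (2¹⁵ and 2¹⁷ rows). It assumes float32 scores, a single head and layer, and no memory-efficient kernel such as FlashAttention; the linear-time comparison assumes a hypothetical state size d_state = 64 stored once per row. Neither figure is a measurement of TabPFN or Hydra.

```python
# Back-of-envelope memory estimate (illustrative assumptions, not measurements):
# float32 values, one attention head, one layer, no FlashAttention-style kernel.

BYTES_PER_FLOAT32 = 4
GIB = 2 ** 30
MIB = 2 ** 20


def attention_matrix_bytes(n_rows: int) -> int:
    """Bytes to store all n_rows x n_rows pairwise attention scores."""
    return n_rows * n_rows * BYTES_PER_FLOAT32


def linear_state_bytes(n_rows: int, d_state: int = 64) -> int:
    """Bytes to keep one hypothetical d_state-sized recurrent state per row."""
    return n_rows * d_state * BYTES_PER_FLOAT32


for exponent in (15, 17):
    n = 2 ** exponent
    print(
        f"n = 2^{exponent}: attention scores {attention_matrix_bytes(n) / GIB:.0f} GiB, "
        f"per-row states {linear_state_bytes(n) / MIB:.0f} MiB"
    )
# n = 2^15: attention scores 4 GiB, per-row states 8 MiB
# n = 2^17: attention scores 64 GiB, per-row states 32 MiB
```

Quadrupling the number of rows multiplies the attention matrix by 16 but the per-row state memory by only 4, which is the practical gap between O(n²) and O(n) scaling.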
Limitations of Existing Methods:
Transformer self-attention mechanisms have quadratic complexity
Direct replacement of Transformer with Mamba introduces sensitivity to input sequence order
Row order in tabular data is semantically meaningless, which conflicts with the causal, order-dependent design of SSMs
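The conflict in the last point can be seen in a toy example. The sketch below is not Mamba: it is a scalar causal linear recurrence h_t = a·h_(t-1) + x_t, with an arbitrarily chosen decay a = 0.5, that summarizes a sequence of row values into its final state. Because earlier rows are discounted by extra powers of a, reordering the same rows changes the summary, while an order-free pooling such as the mean does not.

```python
# Toy illustration, not Mamba: a scalar causal recurrence h_t = a * h_{t-1} + x_t.
# The decay a = 0.5 is an arbitrary choice made for this example.


def causal_scan_final_state(xs, a=0.5):
    """Scan left to right from h = 0 and return the final state."""
    h = 0.0
    for x in xs:
        h = a * h + x  # older rows are discounted by one more factor of a
    return h


def mean_pool(xs):
    """Order-independent summary, for comparison."""
    return sum(xs) / len(xs)


rows = [1.0, 2.0, 3.0]
reordered = [3.0, 2.0, 1.0]  # the same rows in a different order

print(causal_scan_final_state(rows), causal_scan_final_state(reordered))  # 4.25 2.75
print(mean_pool(rows), mean_pool(reordered))  # 2.0 2.0
```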
Research Motivation: To explore structured state space models (SSMs) as alternatives to Transformers, retaining the efficiency of linear complexity while reducing input-order dependence through bidirectional processing.
Proposed Hydra-based TabPFN Architecture: Integrated the bidirectional structured state space model Hydra into TabPFN, achieving time complexity linear in the number of rows.
Introduced Repeated Context Permutation (RCP) Technique: Further reduced the SSM's sensitivity to sequence order by randomly permuting the context rows several times and averaging the resulting predictions.
Achieved Significant Scalability Improvements: Compared to the original TabPFN, the new method can handle datasets four times larger (two powers of two, from 2¹⁵ to 2¹⁷ rows).
Maintained Competitive Predictive Performance: On the OpenML-CC18 benchmark, the accuracy of Hydra-based TabPFN is only 1.1% lower than that of the original model.
Input: number of permutations r, context D, test sample x_test
Output: predicted class value
Initialize an empty list: outputs ← []
for i = 1 to r do
    Shuffle the rows of D: D_p ← shuffle(D)
    Append x_test to D_p: D_in ← D_p ∪ {x_test}
    Predict: outputs[i] ← PFN.predict(D_in)
end for
Return the average of outputs
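The algorithm above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: predict is a hypothetical stand-in for PFN.predict that is assumed to return class probabilities for the last row of its input; toy_order_sensitive_predict is a hypothetical predictor that gives later context rows more weight so that row order visibly matters; and taking the arg max of the averaged probabilities as the predicted class is likewise an assumption.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

# A row is (features, label); the appended test row carries label None.
Row = Tuple[List[float], Optional[int]]


def repeated_context_permutation(
    predict: Callable[[List[Row]], List[float]],
    context: Sequence[Row],
    x_test: List[float],
    r: int = 8,
    seed: int = 0,
) -> List[float]:
    """Average the class probabilities for x_test over r random shuffles of the context."""
    if r < 1:
        raise ValueError("r must be at least 1")
    rng = random.Random(seed)
    outputs = []
    for _ in range(r):
        d_p = list(context)
        rng.shuffle(d_p)               # D_p <- shuffle(D)
        d_in = d_p + [(x_test, None)]  # D_in <- D_p with x_test appended last
        outputs.append(predict(d_in))  # hypothetical stand-in for PFN.predict
    n_classes = len(outputs[0])
    return [sum(p[c] for p in outputs) / r for c in range(n_classes)]


def toy_order_sensitive_predict(d_in: List[Row], n_classes: int = 2, decay: float = 0.5) -> List[float]:
    """Hypothetical order-sensitive predictor: later context rows get more weight."""
    *context, (x_test, _) = d_in
    n = len(context)
    scores = [0.0] * n_classes
    for i, (features, label) in enumerate(context):
        distance = sum(abs(f - t) for f, t in zip(features, x_test))
        scores[label] += decay ** (n - 1 - i) / (1.0 + distance)
    total = sum(scores)
    return [s / total for s in scores]


context = [([0.0], 0), ([0.2], 0), ([1.0], 1), ([1.2], 1)]
x_test = [0.6]

single = toy_order_sensitive_predict(context + [(x_test, None)])
averaged = repeated_context_permutation(toy_order_sensitive_predict, context, x_test, r=32)
predicted_class = max(range(len(averaged)), key=averaged.__getitem__)
print(single, averaged, predicted_class)
```

In this toy, the single prediction favors class 1 only because the class-1 rows happen to come last, even though the two classes sit symmetrically around x_test; averaging over shuffles tends to pull the two probabilities toward each other, which is the effect RCP relies on.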
Bidirectionality Addresses Order Sensitivity: Compared to unidirectional Mamba, Hydra's bidirectional processing enables symmetric context aggregation, reducing input order dependence.
Linear Complexity: Achieves O(n) complexity in sequence length through quasiseparable matrix multiplication, a significant advantage over the Transformer's O(n²).
RCP Strategy: Reduces order sensitivity by averaging predictions over multiple random permutations of the context, a design tailored to the fact that tabular rows have no meaningful order.
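The linear-complexity point can be illustrated with a toy quasiseparable mixer. Under simplifying assumptions (a single channel, one fixed decay a shared by both directions, and a diagonal weight d; Hydra itself uses input-dependent, multi-channel parameters), the matrix M with M[i][j] = a^|i-j| off the diagonal and d on it splits into a strictly lower part computed by a forward scan, a strictly upper part computed by a backward scan, and the diagonal. The sketch below, with the hypothetical function names mix_linear and mix_dense, computes y = Mx both with two O(n) scans and with the O(n²) dense product, and the two agree.

```python
# Toy, single-channel quasiseparable mixer (simplified; not Hydra's parametrization):
#   M[i][j] = a ** abs(i - j) for i != j, and M[i][i] = d.
# The same fixed decay a is used in both directions, which makes M symmetric.


def mix_linear(xs, a=0.5, d=1.0):
    """Compute y = M x in O(n) with one forward and one backward scan."""
    n = len(xs)
    fwd = [0.0] * n  # fwd[i] = sum over j < i of a**(i - j) * xs[j]
    for i in range(1, n):
        fwd[i] = a * (fwd[i - 1] + xs[i - 1])
    bwd = [0.0] * n  # bwd[i] = sum over j > i of a**(j - i) * xs[j]
    for i in range(n - 2, -1, -1):
        bwd[i] = a * (bwd[i + 1] + xs[i + 1])
    return [fwd[i] + bwd[i] + d * xs[i] for i in range(n)]


def mix_dense(xs, a=0.5, d=1.0):
    """Reference O(n^2) product that materializes every entry of M."""
    n = len(xs)
    ys = []
    for i in range(n):
        total = 0.0
        for j in range(n):
            weight = d if i == j else a ** abs(i - j)
            total += weight * xs[j]
        ys.append(total)
    return ys


xs = [1.0, 2.0, 3.0]
print(mix_linear(xs))  # [2.75, 4.0, 4.25]
print(mix_dense(xs))   # [2.75, 4.0, 4.25]
```

Because this toy M is symmetric, any two rows influence each other equally regardless of which comes first, which is one way to read "symmetric context aggregation". The weights still depend on the distance |i - j|, however, so the outputs are not fully permutation-invariant; this residual order dependence is what RCP is meant to reduce.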
Dao et al. (2022) - FlashAttention optimization techniques
Zeng et al. (2024) - TabFlex linear attention method
This paper makes valuable contributions to addressing scalability issues in tabular foundation models. By cleverly combining bidirectional SSMs with a repeated permutation strategy, it successfully balances efficiency and predictive performance. While the work is somewhat limited in theoretical innovation, its practical value and its significance as inspiration for future research merit recognition.