State-Space Models for Tabular Prior-Data Fitted Networks
Koch, Wever, Raisch et al.
Recent advancements in foundation models for tabular data, such as TabPFN, have demonstrated that pretrained Transformer architectures can approximate Bayesian inference with high predictive performance. However, Transformers suffer from quadratic complexity with respect to sequence length, motivating the exploration of more efficient sequence models. In this work, we investigate the potential of using Hydra, a bidirectional linear-time structured state space model (SSM), as an alternative to Transformers in TabPFN. A key challenge lies in the inherent sensitivity of SSMs to the order of input tokens, an undesirable property for tabular datasets, where the row order is semantically meaningless. We investigate to what extent a bidirectional approach can preserve efficiency and enable symmetric context aggregation. Our experiments show that this approach reduces the order dependence, achieving predictive performance competitive with the original TabPFN model.
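Input: number of permutations r, context D, test sample xtest
Output: predicted class value
Initialize an empty list: outputs ← []
for i = 1 to r do
Shuffle the rows of D: Dp ← shuffle(D)
Append xtest to Dp: Din ← Dp ∪ {xtest}
Predict: outputs[i] ← PFN.predict(Din)
end for
Return the mean of outputs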
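
The procedure above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: `predict` is a hypothetical callable standing in for `PFN.predict`, assumed here to take the shuffled context and the test sample separately and to return a list of class probabilities, and `toy_order_sensitive_predict` is an invented stand-in model whose output depends on row order.

```python
import random
from typing import Callable, List, Sequence, Tuple

Row = Tuple[Sequence[float], int]  # (features, class label)
Predictor = Callable[[List[Row], Sequence[float]], List[float]]


def permutation_averaged_predict(
    predict: Predictor,          # hypothetical stand-in for PFN.predict
    context: Sequence[Row],      # context D
    x_test: Sequence[float],     # test sample
    r: int,                      # number of permutations
    seed: int = 0,
) -> List[float]:
    """Average the predictor's outputs over r random row orders of the context."""
    if r < 1:
        raise ValueError("r must be at least 1")
    rng = random.Random(seed)
    outputs = []
    for _ in range(r):
        shuffled = list(context)  # copy, so the caller's context is not mutated
        rng.shuffle(shuffled)     # Dp <- shuffle(D)
        outputs.append(predict(shuffled, x_test))
    n_classes = len(outputs[0])
    # Element-wise mean of the r probability vectors.
    return [sum(out[c] for out in outputs) / r for c in range(n_classes)]


def toy_order_sensitive_predict(context: List[Row], x_test: Sequence[float]) -> List[float]:
    """Invented toy model: later context rows get larger weight; x_test is ignored."""
    scores = [0.0, 0.0]
    for position, (_, label) in enumerate(context, start=1):
        scores[label] += position
    total = sum(scores)
    return [s / total for s in scores]


if __name__ == "__main__":
    D = [([0.0], 0), ([1.0], 1), ([2.0], 1), ([3.0], 0)]
    print(permutation_averaged_predict(toy_order_sensitive_predict, D, [1.5], r=50))
```

Averaging over shuffled contexts trades r forward passes for reduced dependence on any single row order; it does not make the underlying model permutation-invariant.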