2025-11-11T13:04:09.550712

TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification

Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic

TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification

基本信息

  • 论文ID: 2511.05704
  • 标题: TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
  • 作者: Pasan Dissanayake, Sanghamitra Dutta (University of Maryland, College Park)
  • 分类: cs.LG cs.AI cs.CL
  • 发表时间: 2025年11月7日 (arXiv预印本)
  • 论文链接: https://arxiv.org/abs/2511.05704

摘要

Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.

研究背景与动机

问题定义

该研究要解决表格数据分类中的一个核心矛盾:在少样本场景下,基于Transformer的模型虽然性能优异,但参数量巨大,计算复杂度高,难以在实际应用中部署。

问题重要性

  1. 实际应用需求:在金融、医疗、制造等高风险领域,标注数据稀缺是常见问题,如罕见疾病诊断、百年一遇的自然现象预测等
  2. 数据标注成本:金融应用中数据标注昂贵,存在主观性、错误标注、缺乏共识等问题
  3. 部署约束:实际应用需要参数高效且可扩展的模型,以适应不同基础设施水平

现有方法局限性

  1. 传统方法:XGBoost、CatBoost、LightGBM等在充足数据下表现优异,但在少样本场景下性能显著下降
  2. Transformer方法:TabPFN、TabLLM等在少样本场景下表现出色,但参数量达到百万甚至十亿级别,推理成本高昂
  3. 效率与性能权衡:缺乏既保持少样本性能又具备参数效率的解决方案

研究动机

作者提出核心问题:"能否实现两全其美,即既保持参数效率又在有限训练数据下表现良好?"

核心贡献

  1. 提出TabDistill框架:一种将Transformer模型知识蒸馏到神经网络的新策略,实现参数高效的表格数据分类
  2. 双模型实例化:基于TabPFN(~11M参数)和BigScience T0pp(~11B参数)实现框架,蒸馏为约1000参数的MLP
  3. 实验验证:在5个表格数据集上验证,蒸馏后的MLP超越经典基线,某些情况下甚至超过原始Transformer模型
  4. 创新训练策略:引入基于排列的训练技术,避免在极小训练集上过拟合

方法详解

任务定义

给定小规模表格数据集 DN={(xn,yn),xnX,yn{0,1},n=1,...,N}D_N = \{(x_n, y_n), x_n \in X, y_n \in \{0,1\}, n=1,...,N\},其中N10N \sim 10,目标是利用预训练Transformer模型ff的知识生成简单MLP hθ(x):X{0,1}h_\theta(x): X \to \{0,1\}

模型架构

整体框架

TabDistill包含两个阶段:

  • 阶段1:微调基础Transformer模型以生成优质MLP
  • 阶段2:可选的MLP额外微调

核心组件

  1. 基础模型分解
    • 编码器:fE(s):SZf_E(s): S \to Z
    • 解码器:fD(z):Z{0,1}f_D(z): Z \to \{0,1\}
  2. MLP架构
    h_θ(x) = ReLU(W_R ReLU(···ReLU(W_2 ReLU(W_1 x + b_1) + b_2)···) + b_R)
    

    其中R为层数,L为隐藏层宽度
  3. 线性映射
    m_η(z) = LayerNorm(Az + b)
    

    其中ARdim(Θ)×dim(Z)A \in R^{dim(Θ)×dim(Z)}η=(A,b)η = (A,b)

训练流程

阶段1损失函数

L(η; D_N) = Σ[y_n log(σ(h_θ(x_n))[[1]]) + (1-y_n) log(σ(h_θ(x_n))[[0]])]

其中θ=mη(fE(g(DN)))θ = m_η(f_E(g(D_N)))

技术创新点

  1. 超网络思想:借鉴计算机视觉领域经验,将Transformer用作生成神经网络权重的超网络
  2. 排列增强:每个训练epoch随机排列特征顺序,避免过拟合
  3. 参数高效微调:仅微调线性映射参数ηη,保持基础模型参数不变
  4. 双阶段设计:先蒸馏后微调,充分利用预训练知识

具体实例化

TabDistill + TabPFN

  • 直接使用表格数据,g(x)=xg(x) = x(恒等变换)
  • 编码器输出维度:192N192N
  • 映射矩阵维度:dim(Θ)×192Ndim(Θ) × 192N

TabDistill + T0pp

  • 使用文本序列化:"The <column name> is <value>"
  • 编码器输出维度:4096
  • 映射矩阵维度:dim(Θ)×4096dim(Θ) × 4096

实验设置

数据集

使用5个公开表格数据集:

  1. Bank (UCI Bank Marketing):预测客户是否订阅定期存款
  2. Blood (UCI Blood Transfusion):预测是否会献血
  3. Calhousing (California Housing):预测房屋街区是否有价值
  4. Heart (UCI Heart Disease):预测是否患心脏病
  5. Income (Census Income):预测年收入是否超过50K

评价指标

使用ROC-AUC作为主要评价指标,考虑少样本场景下的分类性能。

对比方法

  1. 经典基线:逻辑回归、XGBoost、独立训练的MLP
  2. 基础模型:TabPFN、T0pp (TabLLM)
  3. 蒸馏模型:TabDistill + TabPFN、TabDistill + T0pp

实现细节

  • MLP架构:4层,每层10个神经元(约1000参数)
  • 训练设置:阶段1微调300轮,阶段2额外100轮
  • 超参数优化:使用Weights & Biases进行网格搜索
  • 样本规模:N ∈ {4, 8, 16, 32, 64}

实验结果

主要结果

根据Table 1的ROC-AUC结果:

极少样本场景 (N=4)

  • TabDistill + TabPFN在Bank数据集上达到0.72,显著超过所有经典基线
  • TabDistill + T0pp在多个数据集上表现优异,如Calhousing (0.67) 和Income (0.70)

性能趋势

  1. 随样本增加性能提升:所有方法在N增大时性能普遍改善
  2. 基线方法差异:没有单一经典方法在所有数据集上普遍最优
  3. 模型选择差异:TabDistill + TabPFN整体优于TabDistill + T0pp,但在Income数据集上相反

与基础模型对比

Table 3显示了令人惊讶的结果:

  • 在某些情况下,蒸馏后的MLP超过了原始Transformer模型
  • 例如Bank数据集N=4时:TabDistill + TabPFN (0.72) > TabPFN (0.62)
  • 这表明蒸馏过程不仅压缩了模型,还可能提升了性能

消融实验

模型复杂度影响 (Table 2)

  • 测试不同层数R对性能的影响
  • 结果显示:复杂度超过某个阈值后性能下降
  • 4层架构在大多数情况下表现最佳

特征归因分析 (Figure 3)

使用SHAP分析特征重要性:

  • 蒸馏模型与经典基线在特征重要性上保持一致
  • 即使在特征排列后,模型仍能正确识别重要特征
  • 证明基础模型正确学习了MLP权重与特征顺序的关联

实验发现

  1. 蒸馏效果显著:在极少样本场景下,蒸馏模型明显优于经典方法
  2. 参数效率:从百万/十亿参数压缩到千级参数,效率提升巨大
  3. 知识传递有效:预训练知识成功转移到简单MLP中
  4. 鲁棒性良好:排列增强策略有效防止过拟合

相关工作

表格数据经典算法

  • 传统优势:XGBoost、LightGBM、CatBoost长期主导表格数据领域
  • 少样本局限:从零训练的经典模型在少样本场景下性能显著下降

Transformer表格数据应用

  • SAINT:使用注意力机制建模行列交互,引入自监督预训练
  • TabPFN:在大量合成表格数据上预训练,无需额外训练即可预测新任务
  • TabLLM系列:将表格数据序列化为文本,利用LLM进行分类

元学习与超网络

  • 元学习联系:Transformer擅长上下文学习,类似元学习范式
  • 超网络应用:计算机视觉中已有用Transformer生成神经网络权重的工作
  • 本文创新:首次将此思想应用于表格数据领域

知识蒸馏

  • 传统蒸馏:通过损失函数对齐学生模型与教师模型输出
  • 本文差异:直接从Transformer中提取神经网络,无需损失对齐

结论与讨论

主要结论

  1. 有效性验证:TabDistill成功实现了参数效率与少样本性能的平衡
  2. 性能优势:蒸馏后的MLP在多数情况下超越经典基线,部分场景下甚至超过原始Transformer
  3. 实用价值:提供了一种实际可部署的解决方案,满足不同基础设施需求

局限性

作者诚实指出以下不足:

  1. 大样本性能:当训练样本增多时,性能提升有限
  2. 映射函数简单:当前使用简单线性映射,可能限制性能上限
  3. 偏见继承:蒸馏模型可能继承基础模型的偏见
  4. 应用范围:目前仅验证了二分类任务

未来方向

  1. 映射函数改进:探索更复杂的映射函数以提升性能
  2. 应用扩展:扩展到自然语言推理、指令调优等其他少样本任务
  3. 偏见缓解:通过第二阶段MLP微调减轻基础模型偏见
  4. 多任务学习:探索同时处理多个表格任务的可能性

深度评价

优点

  1. 问题针对性强:准确识别并解决了实际应用中的核心矛盾
  2. 方法创新性:首次将超网络思想应用于表格数据蒸馏
  3. 实验设计完整
    • 多数据集验证
    • 充分的基线对比
    • 详细的消融实验
    • 特征归因分析
  4. 结果令人信服:不仅实现了预期目标,还发现了蒸馏模型超越原模型的有趣现象
  5. 实用价值高:提供了可直接应用的解决方案

不足

  1. 理论分析不足:缺乏对为什么蒸馏模型能超越原模型的理论解释
  2. 数据集规模有限:仅在5个相对小规模数据集上验证
  3. 任务类型单一:只考虑了二分类任务,未涉及回归或多分类
  4. 基础模型选择:只测试了两个基础模型,覆盖面有限
  5. 计算成本分析:未详细比较训练和推理的实际计算成本

影响力

  1. 学术贡献
    • 开创了表格数据Transformer蒸馏的新方向
    • 为少样本学习提供了新的解决思路
    • 连接了超网络与知识蒸馏两个研究领域
  2. 实用价值
    • 解决了实际部署中的重要问题
    • 为资源受限环境提供了可行方案
    • 可直接应用于工业场景
  3. 可复现性
    • 提供了详细的实现细节
    • 开源承诺增强了可复现性
    • 实验设置清晰可重复

适用场景

  1. 资源受限环境:移动设备、边缘计算等场景
  2. 少样本应用:医疗诊断、金融风控、质量检测等数据稀缺领域
  3. 实时推理需求:需要快速响应的在线服务
  4. 模型解释性要求:相比复杂Transformer,简单MLP更易解释

参考文献

论文引用了丰富的相关工作,主要包括:

  • 表格数据经典方法:XGBoost、LightGBM、CatBoost等
  • Transformer表格应用:TabPFN、SAINT、TabLLM系列
  • 知识蒸馏:Hinton等的经典工作
  • 超网络:计算机视觉中的相关应用
  • 元学习:Transformer上下文学习相关研究

总体评价:这是一篇高质量的研究论文,针对实际问题提出了创新解决方案,实验验证充分,具有重要的学术价值和实用价值。虽然存在一些局限性,但为相关领域的发展做出了重要贡献。