2025-11-11T13:04:09.550712

TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification

Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic

TabDistill: Destilación de Transformers en Redes Neuronales para Clasificación Tabular en Pocos Ejemplos

Información Básica

  • ID del Artículo: 2511.05704
  • Título: TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
  • Autores: Pasan Dissanayake, Sanghamitra Dutta (Universidad de Maryland, College Park)
  • Clasificación: cs.LG cs.AI cs.CL
  • Fecha de Publicación: 7 de noviembre de 2025 (preimpresión en arXiv)
  • Enlace del Artículo: https://arxiv.org/abs/2511.05704

Resumen

Los modelos basados en Transformers han demostrado un desempeño prometedor en datos tabulares en comparación con sus contrapartes clásicas, como redes neuronales y Árboles de Decisión Potenciados por Gradiente (GBDT), en escenarios con datos de entrenamiento limitados. Utilizan su conocimiento preentrenado para adaptarse a nuevos dominios, logrando un desempeño encomiable con solo unos pocos ejemplos de entrenamiento, también denominado régimen de pocos ejemplos. Sin embargo, la mejora de desempeño en el régimen de pocos ejemplos tiene un costo de complejidad significativamente aumentada y número de parámetros. Para eludir este compromiso, introducimos TabDistill, una nueva estrategia para destilar el conocimiento preentrenado en modelos complejos basados en Transformers en redes neuronales más simples para clasificar efectivamente datos tabulares. Nuestro marco logra lo mejor de ambos mundos: ser eficiente en parámetros mientras se desempeña bien con datos de entrenamiento limitados. Las redes neuronales destiladas superan líneas base clásicas como redes neuronales regulares, XGBoost y regresión logística bajo igualdad de datos de entrenamiento, y en algunos casos, incluso los modelos originales basados en Transformers de los que fueron destilados.

Contexto de Investigación y Motivación

Definición del Problema

Esta investigación aborda una contradicción central en la clasificación de datos tabulares: en escenarios de pocos ejemplos, aunque los modelos basados en Transformers muestran un desempeño excelente, poseen una cantidad masiva de parámetros y alta complejidad computacional, lo que dificulta su despliegue en aplicaciones prácticas.

Importancia del Problema

  1. Necesidades de Aplicación Práctica: En campos de alto riesgo como finanzas, medicina y manufactura, la escasez de datos anotados es un problema común, como el diagnóstico de enfermedades raras y la predicción de fenómenos naturales centenarios
  2. Costo de Anotación de Datos: En aplicaciones financieras, la anotación de datos es costosa, con problemas de subjetividad, anotaciones erróneas y falta de consenso
  3. Restricciones de Despliegue: Las aplicaciones prácticas requieren modelos eficientes en parámetros y escalables para adaptarse a diferentes niveles de infraestructura

Limitaciones de Métodos Existentes

  1. Métodos Tradicionales: XGBoost, CatBoost, LightGBM, etc., funcionan excelentemente con datos suficientes, pero su desempeño disminuye significativamente en escenarios de pocos ejemplos
  2. Métodos Transformer: TabPFN, TabLLM, etc., muestran un desempeño excepcional en escenarios de pocos ejemplos, pero tienen millones o incluso miles de millones de parámetros, con altos costos de inferencia
  3. Compensación Eficiencia-Desempeño: Falta una solución que mantenga el desempeño en pocos ejemplos mientras proporciona eficiencia de parámetros

Motivación de la Investigación

Los autores plantean la pregunta central: "¿Podemos lograr lo mejor de ambos mundos, manteniendo eficiencia de parámetros mientras se desempeña bien con datos de entrenamiento limitados?"

Contribuciones Principales

  1. Propuesta del Marco TabDistill: Una nueva estrategia para destilar conocimiento de modelos Transformer en redes neuronales, logrando clasificación de datos tabulares eficiente en parámetros
  2. Instanciación Dual de Modelos: Implementación del marco basada en TabPFN (~11M parámetros) y BigScience T0pp (~11B parámetros), destilados en MLP de aproximadamente 1000 parámetros
  3. Verificación Experimental: Validación en 5 conjuntos de datos tabulares, donde el MLP destilado supera líneas base clásicas e incluso, en algunos casos, los modelos Transformer originales
  4. Estrategia de Entrenamiento Innovadora: Introducción de técnicas de entrenamiento basadas en permutaciones para evitar sobreajuste en conjuntos de entrenamiento extremadamente pequeños

Explicación Detallada del Método

Definición de la Tarea

Dado un conjunto de datos tabulares de pequeña escala DN={(xn,yn),xnX,yn{0,1},n=1,...,N}D_N = \{(x_n, y_n), x_n \in X, y_n \in \{0,1\}, n=1,...,N\}, donde N10N \sim 10, el objetivo es utilizar el conocimiento del modelo Transformer preentrenado ff para generar un MLP simple hθ(x):X{0,1}h_\theta(x): X \to \{0,1\}.

Arquitectura del Modelo

Marco General

TabDistill contiene dos etapas:

  • Etapa 1: Ajuste fino del modelo Transformer base para generar un MLP de calidad
  • Etapa 2: Ajuste fino adicional opcional del MLP

Componentes Principales

  1. Descomposición del Modelo Base:
    • Codificador: fE(s):SZf_E(s): S \to Z
    • Decodificador: fD(z):Z{0,1}f_D(z): Z \to \{0,1\}
  2. Arquitectura MLP:
    h_θ(x) = ReLU(W_R ReLU(···ReLU(W_2 ReLU(W_1 x + b_1) + b_2)···) + b_R)
    

    donde R es el número de capas y L es el ancho de la capa oculta
  3. Mapeo Lineal:
    m_η(z) = LayerNorm(Az + b)
    

    donde ARdim(Θ)×dim(Z)A \in R^{dim(Θ)×dim(Z)}, η=(A,b)η = (A,b)

Procedimiento de Entrenamiento

Función de Pérdida de la Etapa 1:

L(η; D_N) = Σ[y_n log(σ(h_θ(x_n))[[1]]) + (1-y_n) log(σ(h_θ(x_n))[[0]])]

donde θ=mη(fE(g(DN)))θ = m_η(f_E(g(D_N)))

Puntos de Innovación Técnica

  1. Idea de Hiperred: Inspirada en experiencias de visión por computadora, utilizando Transformer como una hiperred que genera pesos de redes neuronales
  2. Aumento por Permutación: Permutación aleatoria del orden de características en cada época de entrenamiento para evitar sobreajuste
  3. Ajuste Fino Eficiente en Parámetros: Solo se ajustan los parámetros del mapeo lineal ηη, manteniendo los parámetros del modelo base sin cambios
  4. Diseño de Dos Etapas: Destilación seguida de ajuste fino, aprovechando plenamente el conocimiento preentrenado

Instanciaciones Específicas

TabDistill + TabPFN

  • Uso directo de datos tabulares, g(x)=xg(x) = x (transformación identidad)
  • Dimensión de salida del codificador: 192N192N
  • Dimensión de matriz de mapeo: dim(Θ)×192Ndim(Θ) × 192N

TabDistill + T0pp

  • Serialización de texto: "The <column name> is <value>"
  • Dimensión de salida del codificador: 4096
  • Dimensión de matriz de mapeo: dim(Θ)×4096dim(Θ) × 4096

Configuración Experimental

Conjuntos de Datos

Se utilizan 5 conjuntos de datos tabulares públicos:

  1. Bank (UCI Bank Marketing): Predicción de si los clientes se suscribirán a depósitos a plazo
  2. Blood (UCI Blood Transfusion): Predicción de si una persona donará sangre
  3. Calhousing (California Housing): Predicción de si el valor de un bloque de viviendas es alto
  4. Heart (UCI Heart Disease): Predicción de si una persona tiene enfermedad cardíaca
  5. Income (Census Income): Predicción de si los ingresos anuales superan 50K

Métricas de Evaluación

Se utiliza ROC-AUC como métrica de evaluación principal, considerando el desempeño de clasificación en escenarios de pocos ejemplos.

Métodos de Comparación

  1. Líneas Base Clásicas: Regresión logística, XGBoost, MLP entrenado independientemente
  2. Modelos Base: TabPFN, T0pp (TabLLM)
  3. Modelos Destilados: TabDistill + TabPFN, TabDistill + T0pp

Detalles de Implementación

  • Arquitectura MLP: 4 capas, 10 neuronas por capa (aproximadamente 1000 parámetros)
  • Configuración de Entrenamiento: Etapa 1 con 300 épocas de ajuste fino, Etapa 2 con 100 épocas adicionales
  • Optimización de Hiperparámetros: Búsqueda en cuadrícula usando Weights & Biases
  • Escala de Muestras: N ∈ {4, 8, 16, 32, 64}

Resultados Experimentales

Resultados Principales

Según los resultados de ROC-AUC de la Tabla 1:

Escenario de Pocos Ejemplos Extremos (N=4)

  • TabDistill + TabPFN alcanza 0.72 en el conjunto de datos Bank, superando significativamente todas las líneas base clásicas
  • TabDistill + T0pp muestra un desempeño excelente en múltiples conjuntos de datos, como Calhousing (0.67) e Income (0.70)

Tendencias de Desempeño

  1. Mejora de Desempeño con Aumento de Muestras: Todos los métodos muestran mejora general de desempeño cuando N aumenta
  2. Diferencias en Métodos Base: No existe un único método clásico que sea universalmente óptimo en todos los conjuntos de datos
  3. Diferencias en Selección de Modelos: TabDistill + TabPFN es generalmente superior a TabDistill + T0pp, pero lo opuesto ocurre en el conjunto de datos Income

Comparación con Modelos Base

La Tabla 3 muestra resultados sorprendentes:

  • En algunos casos, el MLP destilado supera al modelo Transformer original
  • Por ejemplo, en el conjunto de datos Bank con N=4: TabDistill + TabPFN (0.72) > TabPFN (0.62)
  • Esto indica que el proceso de destilación no solo comprime el modelo, sino que también puede mejorar el desempeño

Estudios de Ablación

Impacto de la Complejidad del Modelo (Tabla 2)

  • Prueba del impacto del número de capas R en el desempeño
  • Los resultados muestran que el desempeño disminuye cuando la complejidad supera cierto umbral
  • La arquitectura de 4 capas muestra el mejor desempeño en la mayoría de los casos

Análisis de Atribución de Características (Figura 3)

Análisis de importancia de características usando SHAP:

  • El modelo destilado mantiene consistencia con líneas base clásicas en importancia de características
  • Incluso después de permutación de características, el modelo identifica correctamente características importantes
  • Demuestra que el modelo base aprendió correctamente la asociación entre pesos de MLP y orden de características

Hallazgos Experimentales

  1. Efecto de Destilación Significativo: En escenarios de pocos ejemplos extremos, los modelos destilados son claramente superiores a métodos clásicos
  2. Eficiencia de Parámetros: Compresión de millones/miles de millones de parámetros a nivel de miles, con mejora masiva en eficiencia
  3. Transferencia de Conocimiento Efectiva: El conocimiento preentrenado se transfiere exitosamente a MLP simple
  4. Robustez Excelente: La estrategia de aumento por permutación previene efectivamente el sobreajuste

Trabajo Relacionado

Algoritmos Clásicos para Datos Tabulares

  • Ventajas Tradicionales: XGBoost, LightGBM, CatBoost han dominado el campo de datos tabulares durante mucho tiempo
  • Limitaciones en Pocos Ejemplos: Los modelos clásicos entrenados desde cero muestran desempeño significativamente reducido en escenarios de pocos ejemplos

Aplicaciones de Transformers en Datos Tabulares

  • SAINT: Utiliza mecanismos de atención para modelar interacciones fila-columna, introduciendo preentrenamiento autosupervisado
  • TabPFN: Preentrenado en grandes cantidades de datos tabulares sintéticos, capaz de predicción en nuevas tareas sin entrenamiento adicional
  • Serie TabLLM: Serializa datos tabulares a texto, aprovechando LLM para clasificación

Metaaprendizaje e Hiperedes

  • Conexión con Metaaprendizaje: Los Transformers son expertos en aprendizaje en contexto, similar al paradigma de metaaprendizaje
  • Aplicación de Hiperedes: Trabajos previos en visión por computadora han utilizado Transformers para generar pesos de redes neuronales
  • Innovación de este Trabajo: Primera aplicación de esta idea al campo de datos tabulares

Destilación de Conocimiento

  • Destilación Tradicional: Alineación de salidas del modelo estudiante y maestro a través de funciones de pérdida
  • Diferencia de este Trabajo: Extracción directa de redes neuronales de Transformers sin necesidad de alineación de pérdidas

Conclusiones y Discusión

Conclusiones Principales

  1. Validación de Efectividad: TabDistill logra exitosamente el equilibrio entre eficiencia de parámetros y desempeño en pocos ejemplos
  2. Ventajas de Desempeño: El MLP destilado supera líneas base clásicas en la mayoría de los casos, e incluso supera Transformers originales en algunos escenarios
  3. Valor Práctico: Proporciona una solución prácticamente desplegable que satisface necesidades de diferentes niveles de infraestructura

Limitaciones

Los autores señalan honestamente las siguientes deficiencias:

  1. Desempeño con Muestras Grandes: Mejora limitada cuando aumenta el número de muestras de entrenamiento
  2. Función de Mapeo Simple: El uso actual de mapeo lineal simple puede limitar el techo de desempeño
  3. Herencia de Sesgos: Los modelos destilados pueden heredar sesgos de modelos base
  4. Alcance de Aplicación: Actualmente solo validado en tareas de clasificación binaria

Direcciones Futuras

  1. Mejora de Función de Mapeo: Exploración de funciones de mapeo más complejas para mejorar desempeño
  2. Extensión de Aplicaciones: Extensión a razonamiento en lenguaje natural, ajuste fino de instrucciones y otras tareas de pocos ejemplos
  3. Mitigación de Sesgos: Reducción de sesgos de modelos base mediante ajuste fino de MLP en la segunda etapa
  4. Aprendizaje Multitarea: Exploración de posibilidades de manejo simultáneo de múltiples tareas tabulares

Evaluación Profunda

Fortalezas

  1. Especificidad del Problema: Identifica y resuelve con precisión la contradicción central en aplicaciones prácticas
  2. Innovación del Método: Primera aplicación de la idea de hiperred a destilación de datos tabulares
  3. Diseño Experimental Completo:
    • Validación en múltiples conjuntos de datos
    • Comparación exhaustiva de líneas base
    • Estudios de ablación detallados
    • Análisis de atribución de características
  4. Resultados Convincentes: No solo logra objetivos esperados, sino que descubre el fenómeno interesante de que modelos destilados superan modelos originales
  5. Alto Valor Práctico: Proporciona una solución directamente aplicable

Deficiencias

  1. Análisis Teórico Insuficiente: Falta explicación teórica de por qué los modelos destilados pueden superar modelos originales
  2. Escala de Conjuntos de Datos Limitada: Validación solo en 5 conjuntos de datos relativamente pequeños
  3. Tipo de Tarea Único: Solo considera clasificación binaria, sin involucrar regresión o multiclasificación
  4. Selección de Modelos Base Limitada: Solo prueba dos modelos base, cobertura limitada
  5. Análisis de Costo Computacional: Falta comparación detallada de costos computacionales reales de entrenamiento e inferencia

Impacto

  1. Contribución Académica:
    • Abre nueva dirección en destilación de Transformers para datos tabulares
    • Proporciona nuevo enfoque para aprendizaje en pocos ejemplos
    • Conecta dos campos de investigación: hiperedes y destilación de conocimiento
  2. Valor Práctico:
    • Resuelve problema importante en despliegue práctico
    • Proporciona solución viable para entornos con recursos limitados
    • Directamente aplicable a escenarios industriales
  3. Reproducibilidad:
    • Proporciona detalles de implementación detallados
    • Compromiso de código abierto mejora reproducibilidad
    • Configuración experimental clara y repetible

Escenarios de Aplicación

  1. Entornos con Recursos Limitados: Dispositivos móviles, computación de borde y otros escenarios
  2. Aplicaciones de Pocos Ejemplos: Diagnóstico médico, control de riesgos financieros, inspección de calidad y otros campos con datos escasos
  3. Necesidades de Inferencia en Tiempo Real: Servicios en línea que requieren respuesta rápida
  4. Requisitos de Interpretabilidad de Modelos: Comparado con Transformers complejos, MLP simple es más fácil de interpretar

Referencias

El artículo cita trabajos relacionados abundantes, incluyendo principalmente:

  • Métodos clásicos para datos tabulares: XGBoost, LightGBM, CatBoost, etc.
  • Aplicaciones de Transformers en tablas: TabPFN, SAINT, serie TabLLM
  • Destilación de conocimiento: Trabajos clásicos de Hinton et al.
  • Hiperedes: Aplicaciones relacionadas en visión por computadora
  • Metaaprendizaje: Investigación relacionada con aprendizaje en contexto de Transformers

Evaluación General: Este es un artículo de investigación de alta calidad que propone una solución innovadora a un problema práctico, con verificación experimental suficiente y valor académico y práctico importante. Aunque existen algunas limitaciones, ha hecho contribuciones significativas al desarrollo del campo relacionado.