TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic
TabDistill: Destilación de Transformers en Redes Neuronales para Clasificación Tabular en Pocos Ejemplos
Los modelos basados en Transformers han demostrado un desempeño prometedor en datos tabulares en comparación con sus contrapartes clásicas, como redes neuronales y Árboles de Decisión Potenciados por Gradiente (GBDT), en escenarios con datos de entrenamiento limitados. Utilizan su conocimiento preentrenado para adaptarse a nuevos dominios, logrando un desempeño encomiable con solo unos pocos ejemplos de entrenamiento, también denominado régimen de pocos ejemplos. Sin embargo, la mejora de desempeño en el régimen de pocos ejemplos tiene un costo de complejidad significativamente aumentada y número de parámetros. Para eludir este compromiso, introducimos TabDistill, una nueva estrategia para destilar el conocimiento preentrenado en modelos complejos basados en Transformers en redes neuronales más simples para clasificar efectivamente datos tabulares. Nuestro marco logra lo mejor de ambos mundos: ser eficiente en parámetros mientras se desempeña bien con datos de entrenamiento limitados. Las redes neuronales destiladas superan líneas base clásicas como redes neuronales regulares, XGBoost y regresión logística bajo igualdad de datos de entrenamiento, y en algunos casos, incluso los modelos originales basados en Transformers de los que fueron destilados.
Esta investigación aborda una contradicción central en la clasificación de datos tabulares: en escenarios de pocos ejemplos, aunque los modelos basados en Transformers muestran un desempeño excelente, poseen una cantidad masiva de parámetros y alta complejidad computacional, lo que dificulta su despliegue en aplicaciones prácticas.
Necesidades de Aplicación Práctica: En campos de alto riesgo como finanzas, medicina y manufactura, la escasez de datos anotados es un problema común, como el diagnóstico de enfermedades raras y la predicción de fenómenos naturales centenarios
Costo de Anotación de Datos: En aplicaciones financieras, la anotación de datos es costosa, con problemas de subjetividad, anotaciones erróneas y falta de consenso
Restricciones de Despliegue: Las aplicaciones prácticas requieren modelos eficientes en parámetros y escalables para adaptarse a diferentes niveles de infraestructura
Métodos Tradicionales: XGBoost, CatBoost, LightGBM, etc., funcionan excelentemente con datos suficientes, pero su desempeño disminuye significativamente en escenarios de pocos ejemplos
Métodos Transformer: TabPFN, TabLLM, etc., muestran un desempeño excepcional en escenarios de pocos ejemplos, pero tienen millones o incluso miles de millones de parámetros, con altos costos de inferencia
Compensación Eficiencia-Desempeño: Falta una solución que mantenga el desempeño en pocos ejemplos mientras proporciona eficiencia de parámetros
Los autores plantean la pregunta central: "¿Podemos lograr lo mejor de ambos mundos, manteniendo eficiencia de parámetros mientras se desempeña bien con datos de entrenamiento limitados?"
Propuesta del Marco TabDistill: Una nueva estrategia para destilar conocimiento de modelos Transformer en redes neuronales, logrando clasificación de datos tabulares eficiente en parámetros
Instanciación Dual de Modelos: Implementación del marco basada en TabPFN (~11M parámetros) y BigScience T0pp (~11B parámetros), destilados en MLP de aproximadamente 1000 parámetros
Verificación Experimental: Validación en 5 conjuntos de datos tabulares, donde el MLP destilado supera líneas base clásicas e incluso, en algunos casos, los modelos Transformer originales
Estrategia de Entrenamiento Innovadora: Introducción de técnicas de entrenamiento basadas en permutaciones para evitar sobreajuste en conjuntos de entrenamiento extremadamente pequeños
Dado un conjunto de datos tabulares de pequeña escala DN={(xn,yn),xn∈X,yn∈{0,1},n=1,...,N}, donde N∼10, el objetivo es utilizar el conocimiento del modelo Transformer preentrenado f para generar un MLP simple hθ(x):X→{0,1}.
Mejora de Desempeño con Aumento de Muestras: Todos los métodos muestran mejora general de desempeño cuando N aumenta
Diferencias en Métodos Base: No existe un único método clásico que sea universalmente óptimo en todos los conjuntos de datos
Diferencias en Selección de Modelos: TabDistill + TabPFN es generalmente superior a TabDistill + T0pp, pero lo opuesto ocurre en el conjunto de datos Income
Ventajas Tradicionales: XGBoost, LightGBM, CatBoost han dominado el campo de datos tabulares durante mucho tiempo
Limitaciones en Pocos Ejemplos: Los modelos clásicos entrenados desde cero muestran desempeño significativamente reducido en escenarios de pocos ejemplos
Validación de Efectividad: TabDistill logra exitosamente el equilibrio entre eficiencia de parámetros y desempeño en pocos ejemplos
Ventajas de Desempeño: El MLP destilado supera líneas base clásicas en la mayoría de los casos, e incluso supera Transformers originales en algunos escenarios
Valor Práctico: Proporciona una solución prácticamente desplegable que satisface necesidades de diferentes niveles de infraestructura
Especificidad del Problema: Identifica y resuelve con precisión la contradicción central en aplicaciones prácticas
Innovación del Método: Primera aplicación de la idea de hiperred a destilación de datos tabulares
Diseño Experimental Completo:
Validación en múltiples conjuntos de datos
Comparación exhaustiva de líneas base
Estudios de ablación detallados
Análisis de atribución de características
Resultados Convincentes: No solo logra objetivos esperados, sino que descubre el fenómeno interesante de que modelos destilados superan modelos originales
Alto Valor Práctico: Proporciona una solución directamente aplicable
El artículo cita trabajos relacionados abundantes, incluyendo principalmente:
Métodos clásicos para datos tabulares: XGBoost, LightGBM, CatBoost, etc.
Aplicaciones de Transformers en tablas: TabPFN, SAINT, serie TabLLM
Destilación de conocimiento: Trabajos clásicos de Hinton et al.
Hiperedes: Aplicaciones relacionadas en visión por computadora
Metaaprendizaje: Investigación relacionada con aprendizaje en contexto de Transformers
Evaluación General: Este es un artículo de investigación de alta calidad que propone una solución innovadora a un problema práctico, con verificación experimental suficiente y valor académico y práctico importante. Aunque existen algunas limitaciones, ha hecho contribuciones significativas al desarrollo del campo relacionado.