TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic
TabDistill: Distillazione di Transformer in Reti Neurali per la Classificazione Tabulare Few-Shot
I modelli basati su Transformer hanno dimostrato prestazioni promettenti sui dati tabulari rispetto ai loro equivalenti classici come le reti neurali e gli Alberi di Decisione Potenziati da Gradiente (GBDT) in scenari con dati di addestramento limitati. Utilizzano la loro conoscenza pre-addestrata per adattarsi a nuovi domini, ottenendo prestazioni lodevoli con solo pochi esempi di addestramento, noto anche come regime few-shot. Tuttavia, il guadagno di prestazioni nel regime few-shot avviene a scapito di una complessità significativamente aumentata e di un numero di parametri. Per evitare questo compromesso, introduciamo TabDistill, una nuova strategia per distillare la conoscenza pre-addestrata in modelli complessi basati su Transformer in reti neurali più semplici per classificare efficacemente i dati tabulari. Il nostro framework offre il meglio di entrambi i mondi: essere efficiente in termini di parametri mantenendo buone prestazioni con dati di addestramento limitati. Le reti neurali distillate superano i baseline classici come le reti neurali regolari, XGBoost e la regressione logistica con pari dati di addestramento, e in alcuni casi, persino i modelli originali basati su Transformer da cui sono stati distillati.
Questa ricerca affronta una contraddizione fondamentale nella classificazione di dati tabulari: negli scenari few-shot, i modelli basati su Transformer, sebbene performanti, hanno un numero enorme di parametri e un'elevata complessità computazionale, rendendo difficile il loro dispiegamento nelle applicazioni pratiche.
Esigenze Applicative Pratiche: In settori ad alto rischio come finanza, medicina e manifattura, la scarsità di dati annotati è un problema comune, come nella diagnosi di malattie rare o nella previsione di fenomeni naturali centenari
Costi di Annotazione dei Dati: Nelle applicazioni finanziarie l'annotazione dei dati è costosa, con problemi di soggettività, annotazioni errate e mancanza di consenso
Vincoli di Dispiegamento: Le applicazioni pratiche richiedono modelli efficienti in termini di parametri e scalabili, per adattarsi a diversi livelli di infrastruttura
Metodi Tradizionali: XGBoost, CatBoost, LightGBM mostrano prestazioni eccellenti con dati sufficienti, ma le prestazioni diminuiscono significativamente negli scenari few-shot
Metodi Transformer: TabPFN, TabLLM e simili mostrano prestazioni eccellenti negli scenari few-shot, ma hanno parametri che raggiungono milioni o persino miliardi, con costi di inferenza elevati
Compromesso Efficienza-Prestazioni: Manca una soluzione che mantenga le prestazioni few-shot e l'efficienza dei parametri simultaneamente
Gli autori pongono la domanda centrale: "Possiamo ottenere il meglio di entrambi i mondi, mantenendo l'efficienza dei parametri e mostrando buone prestazioni con dati di addestramento limitati?"
Proposta del Framework TabDistill: Una nuova strategia per distillare la conoscenza dai modelli Transformer in reti neurali, realizzando una classificazione tabulare efficiente in termini di parametri
Istanziazione Dual-Model: Implementazione del framework basata su TabPFN (~11M parametri) e BigScience T0pp (~11B parametri), distillati in MLP di circa 1000 parametri
Verifica Sperimentale: Validazione su 5 dataset tabulari, con MLP distillati che superano i baseline classici e in alcuni casi persino i modelli Transformer originali
Strategia di Addestramento Innovativa: Introduzione di tecniche di addestramento basate su permutazioni per evitare l'overfitting su insiemi di addestramento estremamente piccoli
Dato un piccolo dataset tabulare DN={(xn,yn),xn∈X,yn∈{0,1},n=1,...,N}, dove N∼10, l'obiettivo è utilizzare la conoscenza del modello Transformer pre-addestrato f per generare un semplice MLP hθ(x):X→{0,1}.
Validazione dell'Efficacia: TabDistill realizza con successo l'equilibrio tra efficienza dei parametri e prestazioni few-shot
Vantaggi di Prestazione: Gli MLP distillati superano nella maggior parte dei casi i baseline classici, e in alcuni scenari persino i Transformer originali
Valore Pratico: Fornisce una soluzione praticamente distribuibile che soddisfa le esigenze di diverse infrastrutture
Forte Specificità del Problema: Identifica e risolve accuratamente la contraddizione fondamentale nelle applicazioni pratiche
Innovazione del Metodo: Prima applicazione dell'idea di iper-rete alla distillazione di dati tabulari
Design Sperimentale Completo:
Validazione su più dataset
Confronti baseline sufficienti
Esperimenti di ablazione dettagliati
Analisi dell'attribuzione delle caratteristiche
Risultati Convincenti: Non solo realizza gli obiettivi previsti, ma scopre anche il fenomeno interessante che i modelli distillati superano i modelli originali
Alto Valore Pratico: Fornisce una soluzione direttamente applicabile
L'articolo cita lavori correlati ricchi, principalmente includenti:
Metodi classici per dati tabulari: XGBoost, LightGBM, CatBoost, ecc.
Applicazioni Transformer tabulari: TabPFN, SAINT, serie TabLLM
Distillazione di Conoscenza: Lavori classici di Hinton e altri
Iper-Reti: Applicazioni correlate nella visione artificiale
Meta-Apprendimento: Ricerca correlata sull'apprendimento in contesto di Transformer
Valutazione Complessiva: Questo è un articolo di ricerca di alta qualità che propone una soluzione innovativa a problemi pratici, con verifica sperimentale sufficiente e significativo valore accademico e pratico. Sebbene presenti alcune limitazioni, ha fornito importanti contributi allo sviluppo dei campi correlati.