TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic
TabDistill: Destillation von Transformern in neuronale Netze für Few-Shot-Tabellenklassifikation
Transformer-basierte Modelle haben bei Tabellendaten im Vergleich zu klassischen Gegenstücken wie neuronalen Netzen und Gradient Boosted Decision Trees (GBDTs) in Szenarien mit begrenzten Trainingsdaten vielversprechende Leistungen gezeigt. Sie nutzen ihr vortrainiertes Wissen, um sich an neue Domänen anzupassen und erzielen beachtliche Leistungen mit nur wenigen Trainingsbeispielen, auch Few-Shot-Regime genannt. Allerdings geht der Leistungsgewinn im Few-Shot-Regime auf Kosten einer erheblich erhöhten Komplexität und Parameteranzahl. Um diesen Kompromiss zu vermeiden, stellen wir TabDistill vor, eine neue Strategie zur Destillation des vortrainierten Wissens in komplexen Transformer-basierten Modellen in einfachere neuronale Netze zur effektiven Klassifikation von Tabellendaten. Unser Framework bietet das Beste aus beiden Welten: Parametereffizient zu sein und gleichzeitig mit begrenzten Trainingsdaten gut zu funktionieren. Die destillierten neuronalen Netze übertreffen klassische Baselines wie reguläre neuronale Netze, XGBoost und logistische Regression bei gleichen Trainingsdaten und übersteigen in einigen Fällen sogar die ursprünglichen Transformer-basierten Modelle, aus denen sie destilliert wurden.
Diese Forschung befasst sich mit einem grundlegenden Widerspruch bei der Klassifikation von Tabellendaten: Im Few-Shot-Szenario zeigen Transformer-basierte Modelle zwar hervorragende Leistungen, verfügen aber über eine enorme Parameteranzahl und hohe Rechenkomplexität, was ihre praktische Bereitstellung erschwert.
Praktische Anforderungen: In hochriskanten Bereichen wie Finanzen, Medizin und Fertigung ist die Knappheit annotierter Daten ein häufiges Problem, wie bei der Diagnose seltener Krankheiten oder der Vorhersage hundertjähriger Naturphänomene
Kosten der Datenannotation: In Finanzanwendungen ist die Datenannotation teuer und unterliegt Subjektivität, Annotationsfehlern und mangelndem Konsens
Bereitstellungsbeschränkungen: Praktische Anwendungen erfordern parametereffiziente und skalierbare Modelle, um sich an unterschiedliche Infrastrukturniveaus anzupassen
Traditionelle Methoden: XGBoost, CatBoost, LightGBM zeigen bei ausreichenden Daten hervorragende Leistungen, aber ihre Leistung sinkt im Few-Shot-Szenario erheblich
Transformer-Methoden: TabPFN, TabLLM und ähnliche zeigen im Few-Shot-Szenario hervorragende Leistungen, verfügen aber über Millionen bis Milliarden Parameter, was hohe Inferenzkosten verursacht
Effizienz-Leistungs-Kompromiss: Es fehlt eine Lösung, die sowohl Few-Shot-Leistung als auch Parametereffizienz bewahrt
Die Autoren stellen die zentrale Frage: "Können wir das Beste aus beiden Welten erreichen – Parametereffizienz bewahren und gleichzeitig mit begrenzten Trainingsdaten gut funktionieren?"
Vorstellung des TabDistill-Frameworks: Eine neue Strategie zur Destillation von Transformer-Wissen in neuronale Netze, um parametereffiziente Tabellenklassifikation zu erreichen
Zwei Modellinstanziierungen: Framework-Implementierung basierend auf TabPFN (~11M Parameter) und BigScience T0pp (~11B Parameter), destilliert zu MLPs mit etwa 1000 Parametern
Experimentelle Validierung: Validierung auf 5 Tabellendatensätzen zeigt, dass destillierte MLPs klassische Baselines übertreffen und in einigen Fällen sogar die ursprünglichen Transformer-Modelle übertreffen
Innovative Trainingsstrategie: Einführung permutationsbasierter Trainingstechniken zur Vermeidung von Überanpassung bei extrem kleinen Trainingsmengen
Gegeben ein kleiner Tabellendatensatz DN={(xn,yn),xn∈X,yn∈{0,1},n=1,...,N}, wobei N∼10, besteht das Ziel darin, das Wissen eines vortrainierten Transformer-Modells f zu nutzen, um ein einfaches MLP hθ(x):X→{0,1} zu generieren.
Hypernetwork-Idee: Inspiriert von Erfahrungen in der Computervision wird der Transformer als Hypernetwork zur Generierung von Neuronennetzgewichten verwendet
Permutationserweiterung: Zufällige Permutation der Merkmalsreihenfolge in jedem Trainings-Epoch zur Vermeidung von Überanpassung
Parametereffiziente Feinabstimmung: Nur lineare Abbildungsparameter η werden feinabgestimmt, während Basismodellparameter unverändert bleiben
Zweiphasen-Design: Erst destillieren, dann feinabstimmen, um vortrainiertes Wissen vollständig zu nutzen
Validierung der Effektivität: TabDistill erreicht erfolgreich das Gleichgewicht zwischen Parametereffizienz und Few-Shot-Leistung
Leistungsvorteil: Destillierte MLPs übertreffen in den meisten Fällen klassische Baselines und übersteigen in einigen Szenarien sogar die ursprünglichen Transformer
Praktischer Wert: Bietet eine praktisch einsetzbare Lösung, die unterschiedliche Infrastrukturanforderungen erfüllt
Starke Problemorientierung: Genaue Identifikation und Lösung des Kernwiderspruchs in praktischen Anwendungen
Methodische Innovation: Erste Anwendung der Hypernetwork-Idee auf Tabellendata-Destillation
Vollständiges Experimentdesign:
Validierung auf mehreren Datensätzen
Umfassende Baseline-Vergleiche
Detaillierte Ablationsstudien
Merkmalsattributionsanalyse
Überzeugende Ergebnisse: Nicht nur erwartete Ziele erreicht, sondern auch interessantes Phänomen entdeckt, dass destillierte Modelle Originalmodelle übertreffen können
Hoher praktischer Wert: Bietet direkt anwendbare Lösungen
Das Papier zitiert umfangreiche verwandte Arbeiten, hauptsächlich einschließlich:
Klassische Methoden für Tabellendaten: XGBoost, LightGBM, CatBoost usw.
Transformer-Anwendungen auf Tabellen: TabPFN, SAINT, TabLLM-Serie
WissensDestillation: Klassische Arbeiten von Hinton usw.
Hypernetworks: Verwandte Anwendungen in der Computervision
Meta-Learning: Forschung zu Transformer-In-Context-Learning
Gesamtbewertung: Dies ist ein hochqualitatives Forschungspapier, das eine innovative Lösung für praktische Probleme bietet, umfassend experimentell validiert ist und sowohl akademischen als auch praktischen Wert hat. Obwohl es einige Einschränkungen gibt, trägt es wichtig zur Entwicklung verwandter Bereiche bei.