TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification
Dissanayake, Dutta
Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.
academic
TabDistill: Дистилляция трансформаторов в нейронные сети для классификации табличных данных в режиме few-shot
Модели на основе трансформаторов продемонстрировали перспективную производительность на табличных данных по сравнению с классическими методами, такими как нейронные сети и градиентные бустированные деревья решений (GBDT), в сценариях с ограниченными данными обучения. Они используют предварительно полученные знания для адаптации к новым областям, достигая хороших результатов с несколькими примерами обучения, что называется режимом few-shot. Однако улучшение производительности в режиме few-shot достигается за счет значительного увеличения сложности и количества параметров. Чтобы избежать этого компромисса, мы представляем TabDistill — новую стратегию дистилляции предварительно полученных знаний из сложных моделей на основе трансформаторов в более простые нейронные сети для эффективной классификации табличных данных. Наша система обеспечивает лучшее из обоих миров: параметрическую эффективность при хорошей производительности с ограниченными данными обучения. Дистиллированные нейронные сети превосходят классические базовые методы, такие как обычные нейронные сети, XGBoost и логистическая регрессия при равном объеме данных обучения, а в некоторых случаях даже исходные модели на основе трансформаторов, из которых они были дистиллированы.
Данное исследование решает центральное противоречие в классификации табличных данных: в сценариях few-shot модели на основе трансформаторов, хотя и показывают отличную производительность, имеют огромное количество параметров и высокую вычислительную сложность, что затрудняет их развертывание в практических приложениях.
Практические требования приложений: В высокорисковых областях, таких как финансы, здравоохранение и производство, дефицит аннотированных данных является распространённой проблемой, например при диагностике редких заболеваний или прогнозировании столетних природных явлений
Стоимость аннотирования данных: В финансовых приложениях аннотирование данных дорого, подвержено субъективности, ошибкам и отсутствию консенсуса
Ограничения развертывания: Практические приложения требуют параметрически эффективных и масштабируемых моделей, адаптированных к различным уровням инфраструктуры
Традиционные методы: XGBoost, CatBoost, LightGBM показывают отличные результаты при достаточном количестве данных, но значительно теряют в производительности в сценариях few-shot
Методы на основе трансформаторов: TabPFN, TabLLM и другие показывают отличные результаты в режиме few-shot, но имеют параметры на уровне миллионов или даже миллиардов, что приводит к высоким затратам на вывод
Компромисс эффективность-производительность: Отсутствуют решения, которые одновременно сохраняют производительность few-shot и обеспечивают параметрическую эффективность
Авторы ставят центральный вопрос: "Можно ли достичь лучшего из обоих миров, сохраняя параметрическую эффективность и хорошую производительность с ограниченными данными обучения?"
Предложение системы TabDistill: Новая стратегия дистилляции знаний из моделей трансформаторов в нейронные сети, обеспечивающая параметрически эффективную классификацию табличных данных
Двойная реализация модели: Реализация системы на основе TabPFN (~11M параметров) и BigScience T0pp (~11B параметров) с дистилляцией в MLP с ~1000 параметрами
Экспериментальная проверка: Верификация на 5 табличных наборах данных, где дистиллированные MLP превосходят классические базовые методы и в некоторых случаях даже исходные модели трансформаторов
Инновационная стратегия обучения: Введение техники обучения на основе перестановок для избежания переобучения на экстремально малых наборах обучения
Дан небольшой набор табличных данных DN={(xn,yn),xn∈X,yn∈{0,1},n=1,...,N}, где N∼10. Цель состоит в использовании знаний из предварительно обученной модели трансформатора f для генерации простой MLP hθ(x):X→{0,1}.
Проверка эффективности: TabDistill успешно достигает баланса между параметрической эффективностью и производительностью few-shot
Преимущества производительности: Дистиллированная MLP в большинстве случаев превосходит классические базовые методы, а в некоторых сценариях даже исходный трансформатор
Практическая ценность: Предоставляет практически развёртываемое решение, удовлетворяющее различным требованиям инфраструктуры
Высокая целевая направленность проблемы: Точное выявление и решение центрального противоречия в практических приложениях
Инновационность метода: Первое применение идеи гиперсетей к дистилляции табличных данных
Полнота экспериментального дизайна:
Верификация на нескольких наборах данных
Достаточное сравнение с базовыми методами
Подробные абляционные исследования
Анализ атрибуции признаков
Убедительные результаты: Не только достижение ожидаемых целей, но и обнаружение интересного явления, когда дистиллированная модель превосходит исходную
Высокая практическая ценность: Предоставление непосредственно применимого решения
Статья цитирует богатый объём связанных работ, включая в основном:
Классические методы табличных данных: XGBoost, LightGBM, CatBoost и другие
Применение трансформаторов к табличным данным: TabPFN, SAINT, серия TabLLM
Дистилляция знаний: классические работы Hinton и других
Гиперсети: связанные приложения в компьютерном зрении
Метаобучение: исследования контекстного обучения трансформаторов
Общая оценка: Это высококачественная исследовательская статья, предлагающая инновационное решение практической проблемы с достаточной экспериментальной верификацией, обладающая значительной академической и практической ценностью. Несмотря на некоторые ограничения, она вносит важный вклад в развитие соответствующих областей исследований.