Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Purandare, Idreos et al.
While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof of concept implementation for Hyena, which gets up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
academic
Flash Inference: Вывод за почти линейное время для длинных сверточных последовательностных моделей и не только
В данной работе предлагается структура Flash Inference для решения проблемы квадратичной временной сложности при выводе длинных сверточных последовательностных моделей (LCSM). Метод снижает временную сложность точного вывода до квазилинейной O(Llog2L). Подход вдохновлен релаксированной полиномиальной интерполяцией и основан на стратегии разбиения на блоки (tiling) для уменьшения перемещения данных в памяти и совместного использования вычислений. Эксперименты на архитектуре Hyena демонстрируют 7,8-кратное ускорение сквозного вывода и 110-кратное ускорение части позиционного смешивания.
Хотя Transformer достиг огромного успеха в моделях генерации последовательностей, его вычислительная стоимость растет квадратично с длиной последовательности (O(L2)), что становится узким местом как на этапе обучения, так и на этапе вывода. Для решения этой проблемы исследователи предложили различные субквадратичные архитектуры, включая модели пространства состояний (SSM) и длинные сверточные последовательностные модели (LCSM, такие как Hyena).
Эффективность обучения решена: LCSM могут достичь сложности O(LlogL) во время обучения благодаря БПФ
Эффективность вывода не решена: При автогрессивном выводе входная последовательность генерируется пошагово, что не позволяет напрямую использовать БПФ, что приводит к деградации сложности до O(L2)
Требования к длинному контексту: По мере того как большие языковые модели обрабатывают все более длинные контексты, проблема эффективности вывода становится еще более острой
Приближенные методы (Massaroli et al. 2024): Проецируют сверточный фильтр на низкомерную LTI SSM, но это только приближение, требующее дорогостоящего предварительного вычисления дистилляции и не поддерживающее зависящие от данных фильтры
Рекурсивный подход: Может быть эффективным для низкомерных SSM, но остается неэффективным для высокомерных SSM (размерность близка к длине последовательности)
Методы использования структуры: Требуют, чтобы фильтры имели специфическую структуру (например, низкоранговую LTI SSM), что ограничивает выразительность модели
Данная работа направлена на предоставление точного и универсального фреймворка ускорения вывода, не зависящего от специфической структуры фильтра и поддерживающего зависящие от данных фильтры.
Первый квазилинейный алгоритм точного вывода: Предложен алгоритм точного вывода для LCSM с временной сложностью O(Llog2L), что достигает точного моделирования в отличие от предыдущих приближенных методов
Идентификация универсального фреймворка: Определены ключевые архитектурные свойства, делающие быстрый вывод возможным (основанные на вкладе, независимые от запроса), и предложен фреймворк Flash Inference, применимый к более широкому классу архитектур
Параллелизм между слоями: Использование стратегии разбиения на блоки для реализации почти полного параллельного вычисления части позиционного смешивания между слоями
Оптимизация памяти: Через метод разбиения на блоки значительно снижается перемещение данных с Ω(L2) до O(LlogL), экономя 2-кратное хранилище активаций для фильтров, независимых от данных
Эмпирическая проверка: На архитектуре Hyena реализовано 7,8-кратное ускорение сквозного вывода и 110-кратное ускорение части свертки
Автогрессивная генерация последовательности: Учитывая последовательность подсказки x1,…,xp, модель должна пошагово генерировать последующие токены. На каждой позиции i модель вычисляет активации ai[1,M] через все слои, наконец выполняя выборку для генерации xi+1 из aiM.
Основное вычислительное узкое место: Для каждого слоя ℓ и каждого измерения необходимо вычислить:
zt=∑i=1tyi⋅ρt−i
где y — входная последовательность, ρ — сверточный фильтр длины L. Наивная реализация требует Ω(L2) времени.
Расширение Алгоритма 1 на многослойный многомерный случай:
for i = 1 to L-1:
U ← наибольшая степень 2, делящая i
for ℓ = 1 to M: # итерация по слоям
b^ℓ_i += a^{ℓ-1}_i ⊙ ρ^ℓ_0 # красная ячейка
a^ℓ_i = block^ℓ(b^ℓ_i)
b^ℓ[i+1:i+U] += τ(a^{ℓ-1}, [i-U+1, i], ρ^ℓ, [i+1, i+U]) # серый блок
a^0_{i+1} = sampler(a^M_i)
Вычисление серых блоков может выполняться параллельно по всем слоям:
for i = 1 to L-1:
for ℓ = 1 to M:
обработка красных ячеек (должна быть последовательной)
parallel for ℓ = 1 to M:
обработка серых блоков (может быть параллельной)
Преимущества:
Маленькие блоки (87,5% блоков размером ≤4) обычно ограничены задержкой памяти, параллелизм может насытить пропускную способность памяти
Большие блоки используют БПФ, вычислительно интенсивны, параллелизм повышает пропускную способность
Путем модификации стратегии разбиения на блоки (Алгоритм 5) поддерживается случай, когда ρ зависит от данных, с затратой в 2-кратное увеличение вычислений.
CUDA Graphs: запись всех вызовов ядер для генерации одного токена в виде графика, последующее воспроизведение для снижения накладных расходов CPU (улучшение на 10-20%)
Предварительное вычисление БПФ: предварительное вычисление ДПФ сверточного ядра для log2(L)−1 размеров блоков
Предварительная конфигурация FlashFFT: предварительная инициализация конфигурации для разных размеров блоков для максимизации производительности оборудования
Правое заполнение: использование правого заполнения вместо левого, снижение времени вычисления в два раза
Циклическая свертка: использование свойств циклической свертки для уменьшения длины БПФ в два раза
Теория и практика согласуются: сложность O(Llog2L) проявляется в экспериментах как значительное ускорение
Важность пропускной способности памяти: Flash Conv1D хотя и имеет квадратичную сложность, но через оптимизацию доступа к памяти все еще достигает 5-кратного ускорения
Необходимость динамического выбора: нет единственной реализации τ, оптимальной для всех размеров блоков, стратегия Hybrid критична
Накладные расходы CPU не пренебрежимы: CUDA Graphs повышает сквозное ускорение с 1,6× до 8×
Выгода параллелизма: маленькие блоки доминируют (87,5%), параллелизм между слоями дает значительный эффект
Зависящие от данных фильтры: хотя теоретически поддерживаются, требуют 2-кратное увеличение вычислений, экспериментально не полностью проверены
Требования к памяти: все еще требуется хранение всех активаций O(MLD) (vs рекурсивный подход O(MD′))
Область применения:
Не применимо к Transformer (не удовлетворяет независимости от запроса)
Для очень низкомерных SSM (D′≪log2L) рекурсивный подход может быть более оптимальным
Этап подсказки: при длинных подсказках предварительное заполнение (prefill) все еще доминирует во времени, оптимизация автогрессивной генерации имеет относительно ограниченную выгоду
Зависимость от оборудования: эффект ускорения зависит от характеристик пропускной способности памяти GPU
Проектирование архитектур: разработка новых архитектур, удовлетворяющих требованиям Flash Inference и обеспечивающих высокое качество
Причинные зависящие от данных фильтры: как сделать фильтры зависящими от данных при сохранении причинности (Arora et al., Karami & Ghodsi уже показали потенциал)
Полный анализ сложности: от Леммы 1 к Теореме 2, цепочка доказательств ясна
Надлежащая абстракция универсального фреймворка: свойства P.1 и P.2 хорошо абстрагированы, охватывают LCSM и исключают неприменимые случаи (например, Transformer)
Умелое применение математического инструмента: релаксированная полиномиальная интерполяция применена изящно
van der Hoeven, J. (1997). Lazy multiplication of formal power series. ISSAC. Теоретическая основа
Poli, M. et al. (2023). Hyena hierarchy: Towards larger convolutional language models. Основной объект применения
Massaroli, S. et al. (2024). Laughing hyena distillery: Extracting compact recurrences from convolutions. NeurIPS. Сравнение приближенных методов
Gu, A. & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. Связанные работы SSM
Fu, D. Y. et al. (2023). FlashFFTConv: Efficient convolutions for long sequences with tensor cores. Основа реализации
Agarwal, N. et al. (2024). FutureFill: Fast generation from convolutional sequence models. Параллельные работы
Общая оценка: Это отличная статья, тесно интегрирующая теорию и практику. Теоретически она предоставляет первый алгоритм точного вывода O(Llog2L) для LCSM и абстрагирует универсальный фреймворк; практически достигает значительного ускорения через системную оптимизацию. Основные ограничения заключаются в том, что LCSM менее распространены в практических приложениях, чем Transformer, и экспериментальная проверка зависящих от данных фильтров неполна. Данная работа предоставляет новую перспективу оптимизации вывода последовательностных моделей, особенно ценна для проектирования будущих архитектур. Рекомендуется исследователям, интересующимся эффективностью моделей, последовательностным моделированием и системной оптимизацией.