Flash Attention-2: Faster Attention with Better Parallelism and Work Partitioning
[Статья][Код]
На днях Tri Dao (это имя и фамилия, а не три пути из китайской философии) выпустил сиквел знаменитого блокбастера Flash Attention - Flash Attention 2. Обновленная версия Flash Attention примерно в два раза быстрее своей предшественницы и местами до 10 раз бывает быстрее наивной реализации PyTorch.
Введение
Потребность в работе с длинным контекстом возникает в ряде приложений - задачах связанных с написанием или пониманием книг, обработкой видео, аудио. Наивная реализация Attention требует квадратичного по длине последовательности объема памяти и вычислений, из-за чего на практике редко используют контекст более 1k токенов. В свое время много работ и усилий исследователей было потрачено на разработку альтернатив алгоритму внимания в его исходной формулировке, но нельзя сказать, чтобы один из множества подходов оказался конкурентоспособен на широком наборе задач.
В работе по Flash Attention добились значительного ускорения Attention и снижения пикового расхода памяти за счет оптимизации входящих в него операции. При этом сам алгоритм математически эквивалентен (up to numerical precision) исходному Attention.
Напомню, что в основе оригинальной работы по Flash Attention лежат следующие наблюдения:
1️⃣ Большую часть времени занимают не вычисления, а работа с памятью.
2️⃣ Для подсчета Attention не обязательно материализовывать всю матрицу Attention, квадратичную по длине последовательности целиком. Можно считать ее поблочно, а затем агрегировать результат.
Flash Attention уменьшает количество операций по перекачке данных с HBM памяти в кэши GPU за счет выполнения нескольких математических операций сразу в пределах одного блока (kernel fusion). При обратном проходе и подсчете градиентов по параметрам матрица Attention пересчитывается снова поблочно. Flash Attention делает больше вычислений, чем исходный алгоритм, но так как вычисления в разы быстрее перекачки памяти туда-сюда, получается значительный выигрыш в производительности.
Однако, даже оптимизированный Flash Attention сильно недоиспользует возможности современных ускорителей вычислений, достигая всего-лишь 25-40% от теоретической максимальной производительности, в то время как матричные умножения могут достигать 80-90%.
[Статья][Код]
На днях Tri Dao (это имя и фамилия, а не три пути из китайской философии) выпустил сиквел знаменитого блокбастера Flash Attention - Flash Attention 2. Обновленная версия Flash Attention примерно в два раза быстрее своей предшественницы и местами до 10 раз бывает быстрее наивной реализации PyTorch.
Введение
Потребность в работе с длинным контекстом возникает в ряде приложений - задачах связанных с написанием или пониманием книг, обработкой видео, аудио. Наивная реализация Attention требует квадратичного по длине последовательности объема памяти и вычислений, из-за чего на практике редко используют контекст более 1k токенов. В свое время много работ и усилий исследователей было потрачено на разработку альтернатив алгоритму внимания в его исходной формулировке, но нельзя сказать, чтобы один из множества подходов оказался конкурентоспособен на широком наборе задач.
В работе по Flash Attention добились значительного ускорения Attention и снижения пикового расхода памяти за счет оптимизации входящих в него операции. При этом сам алгоритм математически эквивалентен (up to numerical precision) исходному Attention.
Напомню, что в основе оригинальной работы по Flash Attention лежат следующие наблюдения:
1️⃣ Большую часть времени занимают не вычисления, а работа с памятью.
2️⃣ Для подсчета Attention не обязательно материализовывать всю матрицу Attention, квадратичную по длине последовательности целиком. Можно считать ее поблочно, а затем агрегировать результат.
Flash Attention уменьшает количество операций по перекачке данных с HBM памяти в кэши GPU за счет выполнения нескольких математических операций сразу в пределах одного блока (kernel fusion). При обратном проходе и подсчете градиентов по параметрам матрица Attention пересчитывается снова поблочно. Flash Attention делает больше вычислений, чем исходный алгоритм, но так как вычисления в разы быстрее перекачки памяти туда-сюда, получается значительный выигрыш в производительности.
Однако, даже оптимизированный Flash Attention сильно недоиспользует возможности современных ускорителей вычислений, достигая всего-лишь 25-40% от теоретической максимальной производительности, в то время как матричные умножения могут достигать 80-90%.
👍1
Метод
В работе Flash Attention 2 по существу еще слегка подкрутили процедуру вычисления и повысили степень параллелизма самой операций.
Алгоритм вычисления
Автор заметил, что операции, не являющиеся матричным умножением, выполняются куда медленнее (в 16 раз), чем матричные умножения, потому переписал алгоритм так, чтобы уменьшить их количество. Казалось бы, их количество невелико, но тем не менее, они занимают существенную часть общего времени работы. Кроме того, при авторегрессионной генерации нужна лишь верхнетреугольная часть матрицы Attention, и вместо того, чтобы считать ее, а затем занулять, ее просто не считают. Вот так вот!
Благодаря перечисленным выше нововведениям удается добиться ускорения 2-3x.
Параллелизм
Flash Attention-1 параллелизует вычисления по размеру батча и числу голов в трансформере, но если батч не слишком большой или трансформер не очень огромный, то многие streaming multiprocessors (SM) простаивают. И чтобы не оставлять их без дела, предлагается паралеллизовывать вычисления и по длине последовательности. На прямом проходе ряды матрицы Attention можно считать независимо, а на обратном проходе - колонки. И каждый поток обрабатывает свой токен. Кроме того, для уменьшения коммуникации между варпами (группами потоков), оказывается целесообразным держать куски матриц ключей (Key) и значений (Values) общими для групп поток, а Query свою на варп (в Flash Attention-1 было наоборот). Уменьшение количество операций чтения/записи приводит к дополнительному ускорению.
Результаты
Flash-Attention-2 сравнивается с Flash-Attention из оригинального репозитория, реализации на triton и xformers. Для замеров рассматривают последовательности длиной от 512 до 16k токенов, и слой attention со скрытой размерностью 2048 (64 или 128 голов).
FlashAttention-2 в 1.3-1.5x быстрее на прямом проходе, и до 2x быстрее на обратном проходе по сравнению с Flash-Attention - 1 (особенно велик выигрыш при использовании causal mask). Flash-Attention - 2 использует до 72% теоретической производительности A100. На H100 разница еще заметнее.
Выводы
Данная история поучительна тем, что одна и та же математическая операция в зависимости от реализации, может выполняться принципиально разное время. Замечательный пример того, что насколько учет особенностей железа, время работы различных компонент, сильных и слабых сторон ускорителя вычислений важен при проектировании алгоритмов.
В работе Flash Attention 2 по существу еще слегка подкрутили процедуру вычисления и повысили степень параллелизма самой операций.
Алгоритм вычисления
Автор заметил, что операции, не являющиеся матричным умножением, выполняются куда медленнее (в 16 раз), чем матричные умножения, потому переписал алгоритм так, чтобы уменьшить их количество. Казалось бы, их количество невелико, но тем не менее, они занимают существенную часть общего времени работы. Кроме того, при авторегрессионной генерации нужна лишь верхнетреугольная часть матрицы Attention, и вместо того, чтобы считать ее, а затем занулять, ее просто не считают. Вот так вот!
Благодаря перечисленным выше нововведениям удается добиться ускорения 2-3x.
Параллелизм
Flash Attention-1 параллелизует вычисления по размеру батча и числу голов в трансформере, но если батч не слишком большой или трансформер не очень огромный, то многие streaming multiprocessors (SM) простаивают. И чтобы не оставлять их без дела, предлагается паралеллизовывать вычисления и по длине последовательности. На прямом проходе ряды матрицы Attention можно считать независимо, а на обратном проходе - колонки. И каждый поток обрабатывает свой токен. Кроме того, для уменьшения коммуникации между варпами (группами потоков), оказывается целесообразным держать куски матриц ключей (Key) и значений (Values) общими для групп поток, а Query свою на варп (в Flash Attention-1 было наоборот). Уменьшение количество операций чтения/записи приводит к дополнительному ускорению.
Результаты
Flash-Attention-2 сравнивается с Flash-Attention из оригинального репозитория, реализации на triton и xformers. Для замеров рассматривают последовательности длиной от 512 до 16k токенов, и слой attention со скрытой размерностью 2048 (64 или 128 голов).
FlashAttention-2 в 1.3-1.5x быстрее на прямом проходе, и до 2x быстрее на обратном проходе по сравнению с Flash-Attention - 1 (особенно велик выигрыш при использовании causal mask). Flash-Attention - 2 использует до 72% теоретической производительности A100. На H100 разница еще заметнее.
Выводы
Данная история поучительна тем, что одна и та же математическая операция в зависимости от реализации, может выполняться принципиально разное время. Замечательный пример того, что насколько учет особенностей железа, время работы различных компонент, сильных и слабых сторон ускорителя вычислений важен при проектировании алгоритмов.
👍2
Stack More Layers Differently: High-Rank Training Through Low-Rank Updates
[Статья][Код]
Обучение всех параметров больших языков моделей весьма прожорливо по памяти из-за необходимости хранить кроме самой тяжеловесной модели еще и состояния оптимизатора (8 байт на параметр).
LoRA, один из самых ходовых методов PEFT, заключающийся в обучении низкоранговых добавок к весам позволяет сильно сэкономить по памяти, демонстрируя при этом хорошее качество при обучении предобученной модели на downstream задачах. Но низкоранговые представления имеют место при дообучении, в то время как для эффективного предобучения на разнообразных данных желательно использовать все имеющуюся в распоряжении емкость сети - то есть обучение должно быть высокоранговым.
В данной статье авторы предлагают метод последовательного обучения низкоранговых добавок к весам линейных слоев нейронной сети с последующим их слиянием с основными весами. И как утверждается, подобная процедура для достаточно больших сетей (самая большая обученная сеть имеет 350M параметров - сущий пустяк по современным меркам), работает ненамного хуже стандартной полноранговой процедуры обучения.
Метод
Ранг суммы двух и более матриц ограничен сверху суммой рангов матриц. Если низкоранговые матрицы в достаточной мере взаимно независимы, то их сумма может иметь значительно больший ранг чем каждое слагаемое по отдельности. Последовательно обучая низкоранговые добавки возможно в итоге добиться высокорангового изменения весов матрицы, В этом и суть метода.
Однако, чтобы метод заработал, авторам пришлось учесть ряд нюансов и применить пару трюков.
Во-первых, используемый при обучении трансформеров Adam хранит скользящие статистики градиентов, и при переходе к обучению новой низкоранговой добавки, если не предпринимать никаких действий, оптимизация будет проводиться в том же подпространстве, что и у предыдущей LoRA добавки, нивелируя всякий смысл в итеративной процедуре. Для предотвращения такого сценария, авторы зануляют 99% состояний оптимизатора с меньшей абсолютной величиной (почему не все? почему не любую другую долю?) при инициализации новой добавки.
Кроме того, learning rate в момент начала обучения новой добавки зануляется и потом быстро разогревается до примерно того же значения, с которым закончила обучение прошлая добавка (используется cosine annealing learning rate). Без короткой warmup фазы обучение расходится.
Предложенная cтратегия именуется ReLoRA.
[Статья][Код]
Обучение всех параметров больших языков моделей весьма прожорливо по памяти из-за необходимости хранить кроме самой тяжеловесной модели еще и состояния оптимизатора (8 байт на параметр).
LoRA, один из самых ходовых методов PEFT, заключающийся в обучении низкоранговых добавок к весам позволяет сильно сэкономить по памяти, демонстрируя при этом хорошее качество при обучении предобученной модели на downstream задачах. Но низкоранговые представления имеют место при дообучении, в то время как для эффективного предобучения на разнообразных данных желательно использовать все имеющуюся в распоряжении емкость сети - то есть обучение должно быть высокоранговым.
В данной статье авторы предлагают метод последовательного обучения низкоранговых добавок к весам линейных слоев нейронной сети с последующим их слиянием с основными весами. И как утверждается, подобная процедура для достаточно больших сетей (самая большая обученная сеть имеет 350M параметров - сущий пустяк по современным меркам), работает ненамного хуже стандартной полноранговой процедуры обучения.
Метод
Ранг суммы двух и более матриц ограничен сверху суммой рангов матриц. Если низкоранговые матрицы в достаточной мере взаимно независимы, то их сумма может иметь значительно больший ранг чем каждое слагаемое по отдельности. Последовательно обучая низкоранговые добавки возможно в итоге добиться высокорангового изменения весов матрицы, В этом и суть метода.
Однако, чтобы метод заработал, авторам пришлось учесть ряд нюансов и применить пару трюков.
Во-первых, используемый при обучении трансформеров Adam хранит скользящие статистики градиентов, и при переходе к обучению новой низкоранговой добавки, если не предпринимать никаких действий, оптимизация будет проводиться в том же подпространстве, что и у предыдущей LoRA добавки, нивелируя всякий смысл в итеративной процедуре. Для предотвращения такого сценария, авторы зануляют 99% состояний оптимизатора с меньшей абсолютной величиной (почему не все? почему не любую другую долю?) при инициализации новой добавки.
Кроме того, learning rate в момент начала обучения новой добавки зануляется и потом быстро разогревается до примерно того же значения, с которым закончила обучение прошлая добавка (используется cosine annealing learning rate). Без короткой warmup фазы обучение расходится.
Предложенная cтратегия именуется ReLoRA.
👍3
Эксперименты
Авторы обучают семейство декодерных моделей моделей от 60 до 350M (типичный размер языковых моделей в 18-19 году) на данных из C4. Архитектура модели повторяет LLaMA.
Процедура обучения состоит из первоначальной фазы полнорангового обучения (т.е обучения всех параметров модели) в течение 5k шагов и 3 циклов обучения низкоранговых добавок на протяжении тех же 5k шагов (с warmup фазой в 100 шагов при переходе к новой LoRA). Пиковый расход памяти такой же, как и в стандартной процедуре обучения.
В качестве бейзлайнов используются:
◦ Стандартное обучение
◦ Обучение меньшей модели с таким же количеством обучаемых параметров, как с LoRA (Control)
◦ LoRA
Метод ожидаемо бьет LoRA, обладая большей выразительностью, и меньшую сеть с тем же числом обучаемых параметров (за исключением самой маленькой модели), при этом несколько уступая стандартной процедуре обучения.
Авторы анализируют спектральное разложение обученных матриц, и у ReLoRA оно больше напоминает изменение весов при обучении всех параметров (по сравнению с LoRA), хоть все еще заметно отличается.
Ablation показывает, что все компоненты метода важны для приемлемого результата - первичная процедура стандартного обучения, зануление состояний отпимизатора и warmup.
Заключение
Довольно интересный и разумный подход. Применимость его в качестве претрейна, по моему мнению, ограничена, из-за необходимости фазы высорангового обучения в начале, из-за чего большие LLM-ки какое-то время придется обучать на множестве хостов. Основной выигрыш может быть при файнтьюнинге на достаточно больших и разнообразных задачах, где выразительности низкоранговых добавок недостаточно.
Авторы обучают семейство декодерных моделей моделей от 60 до 350M (типичный размер языковых моделей в 18-19 году) на данных из C4. Архитектура модели повторяет LLaMA.
Процедура обучения состоит из первоначальной фазы полнорангового обучения (т.е обучения всех параметров модели) в течение 5k шагов и 3 циклов обучения низкоранговых добавок на протяжении тех же 5k шагов (с warmup фазой в 100 шагов при переходе к новой LoRA). Пиковый расход памяти такой же, как и в стандартной процедуре обучения.
В качестве бейзлайнов используются:
◦ Стандартное обучение
◦ Обучение меньшей модели с таким же количеством обучаемых параметров, как с LoRA (Control)
◦ LoRA
Метод ожидаемо бьет LoRA, обладая большей выразительностью, и меньшую сеть с тем же числом обучаемых параметров (за исключением самой маленькой модели), при этом несколько уступая стандартной процедуре обучения.
Авторы анализируют спектральное разложение обученных матриц, и у ReLoRA оно больше напоминает изменение весов при обучении всех параметров (по сравнению с LoRA), хоть все еще заметно отличается.
Ablation показывает, что все компоненты метода важны для приемлемого результата - первичная процедура стандартного обучения, зануление состояний отпимизатора и warmup.
Заключение
Довольно интересный и разумный подход. Применимость его в качестве претрейна, по моему мнению, ограничена, из-за необходимости фазы высорангового обучения в начале, из-за чего большие LLM-ки какое-то время придется обучать на множестве хостов. Основной выигрыш может быть при файнтьюнинге на достаточно больших и разнообразных задачах, где выразительности низкоранговых добавок недостаточно.