We reverse-engineered Flash Attention 4
Новая версия Flash Attention 4 оптимизирована под архитектуру Blackwell от Nvidia и обещает ~20% прирост скорости по сравнению с предыдущим рекордсменом — закрытыми ядрами внимания в библиотеке cudnn. Хотя официального отчёта нет, исходный код уже доступен, что позволило разобрать его устройство. Главное изменение — не математические трюки (вроде быстрых приближённых экспонент или эффективного онлайн-softmax), а сложная асинхронная конвейеризация операций, напоминающая принципы параллельного программирования из высокопроизводительных систем вроде баз данных или веб-серверов.
Архитектура FA4 построена вокруг обработки «тайлов» — блоков данных, которые потоково считываются из глобальной памяти GPU. Один экземпляр ядра обрабатывает два тайла запросов, последовательно сканируя все ключи и значения, чтобы вычислить взвешенные выходные данные. Это напоминает векторized-сканирование в СУБД. Масштабирование достигается за счёт массового параллельного запуска таких программ по модели «одна программа — много данных». Подход требует глубокой асинхронности и эффективного использования warp-ов, но остаётся интуитивно понятным для инженеров, работавших с конкурентными системами.