Writing Speed-of-Light Flash Attention for 5090 in CUDA C++
Flash Attention на 5090 в CUDA C++
Цель — научиться писать attention-ядро на CUDA C++, чтобы использовать MXFP8/NVFP4 MMA для sm120, чего нет в Triton.
Код: learn-cuda/07_attention.
Бенчмарк (bs=1, heads=8, q=4096, kv=8192, BF16, 5090@400 W, CUDA 12.9, SOL 209.5 TFLOPS):
ядро | TFLOPS | %SOL |
---|---|---|
F.sdpa (Flash) | 186.73 | 89.13 |
F.sdpa (CuDNN) | 203.61 | 97.19 |
flash-attn | 190.58 | 90.97 |
v1 (basic) | 142.87 | 68.20 |
v2 (swizzle) | 181.11 | 86.45 |
v3 (2-stage) | 189.84 | 90.62 |
v4 (ldmatrix.x4) | 194.33 | 92.76 |
v5 (pipe) | 197.74 | 94.39 |
Алгоритм Flash Attention 2
Псевдокод:
scale = DIM**-0.5
for b, tile_Q:
tile_O = 0
tile_Q = load(Q[b, tile_Q])
for tile_KV:
tile_K = load(K[b, tile_KV])
tile_S = tile_Q @ tile_K.T * scale
online_softmax(tile_S) # in-place
tile_V = load(V[b, tile_KV])
tile_O += tile_S @ tile_V
store(O[b, tile_Q])
head_dim=128 помещается в регистры.
v1 — базовая версия
- G2S:
cp.async.ca.shared.global
128-битными транзакциями. - S2R:
ldmatrix
для Q, K, V → 8×8 фрагменты. - Softmax online:
m = max(m_prev, m_curr)
d = d_prev * exp(m_prev - m) + Σ exp(S - m)
O = O_prev * (d_prev/d) * exp(m_prev - m) + (exp(S - m)/d) @ V
v2 — swizzled shared memory
- 128-битные банки → конфликты при 8×8 tile.
- Swizzle
K
иV
по 32-битным строкам;Q
оставляем линейно. - +40 % пропускной способности.
v3 — 2-stage pipeline
- Двойной буфер: пока вычисляем S/P@V, асинхронно грузим следующий KV.
cp.async.commit_group()
+cp.async.wait_group(1)
.- +5 % к SOL.
v4 — ldmatrix.x4
- Одна инструкция
ldmatrix.x4
загружает 4×8×8 фрагмента K/V за раз. - Снижает инструкций на 25 %.
- +2 % к SOL.
v5 — улучшенный pipeline
- 3-4 стадии:
- prefetch KV
- compute S
- compute P@V
- write-back O
__pipeline_wait_prior(N)
+__pipeline_commit()
.- +2 % к SOL.
Что дальше
- Использовать TMA (
cp.async.bulk
) и NVFP4/MXFP8 MMA. - Поддержка head_dim > 128 (FlashMLA).
Комментарии (32)
- Пользователи удивлены, что RTX 5090 даёт всего 209 TFLOPS BF16 — менее 10 % от серверного Blackwell B200 (2250 TFLOPS), но при цене ~$30-40 k за B200 соотношение цена/производительность почти сравнялось.
- Обсуждают, что NVIDIA с 4090 и далее искусственно ограничивает тензорные ядра игровых карт для ML-операций FP8/FP16.
- У 5090 выше TDP, чем у 4090, и можно ограничить мощность лишь до 70 % (4090 — до 50 %), что мешает апгрейду для ML-станций.
- Появились вопросы о поддержке Flash Attention на 5090/5080 и о нативной компиляции под Blackwell в PyTorch 2.7.
- Участники спорят, стоит ли вкладываться в Triton, если нужны фирменные типы NVFP4/MXFP8, которых там пока нет.
Show HN: Luminal – Open-source, search-based GPU compiler
luminal — библиотека для глубокого обучения, работающая «со скоростью света».
Основное
- Язык: Rust
- Цель: максимально быстрое вычисление градиентов и обучение нейросетей.
- Подход: компиляция вычислительного графа в высокооптимизированный нативный код (LLVM).
Возможности
- Автоматическое дифференцирование.
- JIT-компиляция графов.
- Поддержка CPU и GPU (CUDA).
- Минимальные накладные расходы: нет Python-интерпретатора и лишних библиотек.
Примеры
let x = Cpu::tensor([1.0, 2.0, 3.0]);
let y = x.relu().sum();
let g = y.backward(); // градиент за наносекунды
Установка
cargo add luminal
Статус
Проект в активной разработке; API может меняться.
Комментарии (53)
- Luminal — это ML-фреймворк, который вместо ручных правил формулирует оптимизацию как поиск по огромному пространству возможных ядер (tiling, потоки, инструкции и т.д.) с помощью e-graphs.
- Сейчас на M-серии MacBook Llama-3 8B Q8 выдаёт 15-25 ток/с; это ниже llama.cpp, но команда строит трекер производительности и продолжает улучшать поиск.
- Поиск ограничен 12 базовыми линейно-алгебраическими операциями, что делает задачу похожей на «superoptimisation» и позволяет добавлять аппаратно-специфичные инструкции (tensor cores, PTX/ASM) без роста frontend.
- Для оценки качества ядра используется реальное время выполнения на целевом железе; масштабировать планируют распараллеленным профилированием на кластерах GPU.
- Отличие от TVM/tinygrad — единое пространство поиска, включающее как параметры тайлинга, так и алгебраические преобразования (например, softmax → flash-attention).