Writing Speed-of-Light Flash Attention for 5090 in CUDA C++
Flash Attention на 5090 в CUDA C++
Цель — научиться писать attention-ядро на CUDA C++, чтобы использовать MXFP8/NVFP4 MMA для sm120, чего нет в Triton.
Код: learn-cuda/07_attention.
Бенчмарк (bs=1, heads=8, q=4096, kv=8192, BF16, 5090@400 W, CUDA 12.9, SOL 209.5 TFLOPS):
ядро | TFLOPS | %SOL |
---|---|---|
F.sdpa (Flash) | 186.73 | 89.13 |
F.sdpa (CuDNN) | 203.61 | 97.19 |
flash-attn | 190.58 | 90.97 |
v1 (basic) | 142.87 | 68.20 |
v2 (swizzle) | 181.11 | 86.45 |
v3 (2-stage) | 189.84 | 90.62 |
v4 (ldmatrix.x4) | 194.33 | 92.76 |
v5 (pipe) | 197.74 | 94.39 |
Алгоритм Flash Attention 2
Псевдокод:
scale = DIM**-0.5
for b, tile_Q:
tile_O = 0
tile_Q = load(Q[b, tile_Q])
for tile_KV:
tile_K = load(K[b, tile_KV])
tile_S = tile_Q @ tile_K.T * scale
online_softmax(tile_S) # in-place
tile_V = load(V[b, tile_KV])
tile_O += tile_S @ tile_V
store(O[b, tile_Q])
head_dim=128 помещается в регистры.
v1 — базовая версия
- G2S:
cp.async.ca.shared.global
128-битными транзакциями. - S2R:
ldmatrix
для Q, K, V → 8×8 фрагменты. - Softmax online:
m = max(m_prev, m_curr)
d = d_prev * exp(m_prev - m) + Σ exp(S - m)
O = O_prev * (d_prev/d) * exp(m_prev - m) + (exp(S - m)/d) @ V
v2 — swizzled shared memory
- 128-битные банки → конфликты при 8×8 tile.
- Swizzle
K
иV
по 32-битным строкам;Q
оставляем линейно. - +40 % пропускной способности.
v3 — 2-stage pipeline
- Двойной буфер: пока вычисляем S/P@V, асинхронно грузим следующий KV.
cp.async.commit_group()
+cp.async.wait_group(1)
.- +5 % к SOL.
v4 — ldmatrix.x4
- Одна инструкция
ldmatrix.x4
загружает 4×8×8 фрагмента K/V за раз. - Снижает инструкций на 25 %.
- +2 % к SOL.
v5 — улучшенный pipeline
- 3-4 стадии:
- prefetch KV
- compute S
- compute P@V
- write-back O
__pipeline_wait_prior(N)
+__pipeline_commit()
.- +2 % к SOL.
Что дальше
- Использовать TMA (
cp.async.bulk
) и NVFP4/MXFP8 MMA. - Поддержка head_dim > 128 (FlashMLA).