Hacker News Digest

23 августа 2025 г. в 12:29 • gau-nernst.github.io • ⭐ 145 • 💬 32

OriginalHN

#cuda#c++#nvidia#flash-attention#machine-learning#gpu-computing#high-performance-computing

Writing Speed-of-Light Flash Attention for 5090 in CUDA C++

Flash Attention на 5090 в CUDA C++

Цель — научиться писать attention-ядро на CUDA C++, чтобы использовать MXFP8/NVFP4 MMA для sm120, чего нет в Triton.
Код: learn-cuda/07_attention.

Бенчмарк (bs=1, heads=8, q=4096, kv=8192, BF16, 5090@400 W, CUDA 12.9, SOL 209.5 TFLOPS):

ядро TFLOPS %SOL
F.sdpa (Flash) 186.73 89.13
F.sdpa (CuDNN) 203.61 97.19
flash-attn 190.58 90.97
v1 (basic) 142.87 68.20
v2 (swizzle) 181.11 86.45
v3 (2-stage) 189.84 90.62
v4 (ldmatrix.x4) 194.33 92.76
v5 (pipe) 197.74 94.39

Алгоритм Flash Attention 2

Псевдокод:

scale = DIM**-0.5
for b, tile_Q:
    tile_O = 0
    tile_Q = load(Q[b, tile_Q])
    for tile_KV:
        tile_K = load(K[b, tile_KV])
        tile_S = tile_Q @ tile_K.T * scale
        online_softmax(tile_S)  # in-place
        tile_V = load(V[b, tile_KV])
        tile_O += tile_S @ tile_V
    store(O[b, tile_Q])

head_dim=128 помещается в регистры.


v1 — базовая версия

  1. G2S: cp.async.ca.shared.global 128-битными транзакциями.
  2. S2R: ldmatrix для Q, K, V → 8×8 фрагменты.
  3. Softmax online:
    • m = max(m_prev, m_curr)
    • d = d_prev * exp(m_prev - m) + Σ exp(S - m)
    • O = O_prev * (d_prev/d) * exp(m_prev - m) + (exp(S - m)/d) @ V

v2 — swizzled shared memory

  • 128-битные банки → конфликты при 8×8 tile.
  • Swizzle K и V по 32-битным строкам; Q оставляем линейно.
  • +40 % пропускной способности.

v3 — 2-stage pipeline

  • Двойной буфер: пока вычисляем S/P@V, асинхронно грузим следующий KV.
  • cp.async.commit_group() + cp.async.wait_group(1).
  • +5 % к SOL.

v4 — ldmatrix.x4

  • Одна инструкция ldmatrix.x4 загружает 4×8×8 фрагмента K/V за раз.
  • Снижает инструкций на 25 %.
  • +2 % к SOL.

v5 — улучшенный pipeline

  • 3-4 стадии:
    1. prefetch KV
    2. compute S
    3. compute P@V
    4. write-back O
  • __pipeline_wait_prior(N) + __pipeline_commit().
  • +2 % к SOL.

Что дальше

  • Использовать TMA (cp.async.bulk) и NVFP4/MXFP8 MMA.
  • Поддержка head_dim > 128 (FlashMLA).