📚 Modern CUDA Learn Notes with PyTorch for Beginners 🐑

📚200+ CUDA Kernels | 📚100+ Blogs | ⚡️HGEMM MMA | ⚡️FA-2 MMA

📚 Modern CUDA Learn Notes with PyTorch for Beginners: covers Tensor/CUDA Cores and TF32/F16/BF16/F8, with 📖200+ CUDA Kernels🔥🔥 (Easy -> Hard++) and their PyTorch bindings, 📖100+ LLM/VLM/CV/CUDA/CuTe🔥 blogs, 📖toy-hgemm⚡️⚡️, which achieves 98%~100% of cuBLAS performance, and 📖flash-attention-mma⚡️⚡️, built on Tensor Cores with pure MMA PTX. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉

📖 News 🔥🔥

  • [2024-12-02]: The HGEMM MMA kernels have been refactored into 🤖hgemm-mma: ⚡️write HGEMM from scratch using Tensor Cores with the WMMA, MMA and CuTe APIs, achieving peak⚡️ performance.

📖 HGEMM Benchmark 🎉🎉

Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, the HGEMM (WMMA/MMA/CuTe) kernels in this repo (blue🔵) achieve 98%~100% of the performance of cuBLAS's default Tensor Cores algorithm (orange🟠). Please check the toy-hgemm library⚡️⚡️ or the hgemm-mma⚡️⚡️ repo for more details.

[Figure: toy-hgemm library HGEMM benchmark vs cuBLAS]

📚Feature 📚Feature 📚Feature 📚Feature
✔️CUDA/Tensor Cores ✔️Loop over K ✔️Tile Block(BMxBK) ✔️Tile Threads(T 8x8)
✔️WMMA(m16n16k16) ✔️MMA(m16n8k16) ✔️Pack LDST(128 bits) ✔️SMEM Padding
✔️Copy Async ✔️Tile MMAs ✔️Tile Warps ✔️Multi Stages(2~4)
✔️Register Double Buffers ✔️Block Swizzle ✔️Warp Swizzle ✔️SMEM Swizzle(CuTe/MMA)
✔️Collective Store(Shfl) ✔️Layout NN ✔️Layout TN ✔️SGEMM FP32/TF32
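
For orientation, the sketch below shows the bare WMMA m16n16k16 pattern that the optimized kernels build on: one warp per 16x16 output tile, row-major A/B, no tiling, pipelining or swizzling. The kernel name and launch configuration are illustrative only, not this repo's API.

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Naive WMMA HGEMM sketch: one warp per 16x16 tile of C, launched with
// 32 threads per block and grid = (N/16, M/16); assumes M, N, K are
// multiples of 16. The repo's kernels add block/warp tiling, cp.async
// multi-stage pipelines, swizzling and collective stores on top of this.
__global__ void wmma_hgemm_naive_sketch(const half* A, const half* B, half* C,
                                        int M, int N, int K) {
  const int tile_m = blockIdx.y;  // which 16-row tile of C
  const int tile_n = blockIdx.x;  // which 16-col tile of C

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  // Loop over K in steps of 16, accumulating A_tile * B_tile into c_frag.
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n * 16, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }
  wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag, N,
                          wmma::mem_row_major);
}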

📖 FA2-MMA Benchmark 🎉🎉

I have also implemented FlashAttention-2 using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, Fully Shared QKV SMEM, Prefetch Q s2r, Prefetch K/V g2s, QKV Fine-grained Tiling, Collective Store, etc. Please refer to flash-attention-mma⚡️⚡️ for more details.

[Figure: flash-attn-mma kernel overview]

📚Feature 📚Feature 📚Feature 📚Feature
✔️Tensor Cores ✔️Loop over N/D ✔️Tile Block(Br, Bc) ✔️MMA(m16n8k16)
✔️Pack LDST(128 bits) ✔️SMEM Swizzle/Padding ✔️Copy Async ✔️Tile MMAs
✔️Tile Warps ✔️Multi Stages(1/2) ✔️Collective Store(Shfl) ✔️Split KV/Q
✔️Shared QKV SMEM ✔️Prefetch Q s2r ✔️Prefetch KV g2s ✔️QKV Fine-grained Tiling
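
All of these kernels are built around the standard FlashAttention online-softmax update: each new K/V tile rescales the running row max, row sum and partial output. Below is a scalar, per-row sketch of that update for illustration only; the actual kernels perform it per warp on MMA register fragments.

// Online-softmax update for one query row: fold one tile of scores s[0..Bc)
// and values v[Bc][d] into the running max m_i, running sum l_i and the
// unnormalized partial output acc_o[0..d). Scalar sketch, not the MMA version.
__device__ __forceinline__ void online_softmax_update(
    float& m_i, float& l_i, float* acc_o, int d,
    const float* s, const float* v, int Bc) {
  // 1) New running max over this tile.
  float m_new = m_i;
  for (int j = 0; j < Bc; ++j) m_new = fmaxf(m_new, s[j]);
  // 2) Rescale the old sum and partial output by exp(m_old - m_new).
  const float scale = __expf(m_i - m_new);
  l_i *= scale;
  for (int k = 0; k < d; ++k) acc_o[k] *= scale;
  // 3) Accumulate exp(s - m_new) into the sum and exp(s - m_new) * V into O.
  for (int j = 0; j < Bc; ++j) {
    const float p = __expf(s[j] - m_new);
    l_i += p;
    for (int k = 0; k < d; ++k) acc_o[k] += p * v[j * d + k];
  }
  m_i = m_new;
  // After the last tile, the final output row is acc_o[k] / l_i.
}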

Currently, for small-scale attention (B <= 4, H <= 48, SeqLen <= 8192, D <= 64) it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the 📚 Split Q + Fully Shared QKV SMEM method achieves 55 TFLOPS (D=64), which is roughly ~1.5x 🎉 faster than FA2. On NVIDIA L20, the 🤖ffpa-attn-mma method achieves 104 TFLOPS (D=512), which is roughly ~1.8x 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention a performance gap remains. Stay tuned for updates ~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

| Algorithm | (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
|---|---|---|---|---|
| FlashAttention-2 | (1,8,8192,64) | 37 TFLOPS | 100 TFLOPS | 145 TFLOPS |
| share-qkv+stage2 | (1,8,8192,64) | 55 TFLOPS | 99 TFLOPS | 221 TFLOPS |
| FlashAttention-2 | (1,48,8192,64) | 37 TFLOPS | 109 TFLOPS | 163 TFLOPS |
| share-qkv+stage2 | (1,48,8192,64) | 48 TFLOPS | 107 TFLOPS | 224 TFLOPS |
| SDPA(EFFICIENT ATTENTION) | (1,48,8192,512) | 16 TFLOPS | 58 TFLOPS | 85 TFLOPS |
| 🤖ffpa-attn-mma | (1,48,8192,512) | 39 TFLOPS | 104 TFLOPS | 200 TFLOPS |
| Precision errors vs FA2/SDPA | / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

Both the Split KV and Split Q implementations are provided in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV method, which splits all of Q, K and V across MMAs (warps), is slower than the Split Q method, which splits only Q across MMAs (warps) while keeping K and V accessible to all MMAs (warps).

  • 📚 Split KV (Basic, FlashAttention-1)
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: The layout of 8 MMA(2x4)  [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64: 
// |  [64,64]  |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_kv_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q (Faster, FlashAttention-2)
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
// in order to reduce the comm between warps via smem and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q + Shared KV SMEM (1/2 SRAM vs FA2)
// K and V share the same shared memory to improve block occupancy.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
// Q, K and V fully share the same shared memory, with Q prefetched s2r, which improves
// block occupancy and reduces Q SMEM IO-access.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
// Fine-grained tiling at the MMA level for Q@K^T results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q + Fully QKV Fine-grained Tiling (O(2xBrx16)~O(1) SRAM vs FA2 O(4xBrxd) SRAM)
// Fine-grained tiling at the MMA level for both Q@K^T and P@V results in a constant SRAM usage of
// Br * 16 or Bc * 16 for Q, K, V, leading to an overall SRAM complexity of O(Br * 16). Consequently,
// this approach allows us to run faster than SDPA with or without MMA Acc F32.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qkv_kernel(half* Q, half* K, half* V, half* O, ...);

💡NOTE: 📚Split Q + Fully QKV Fine-grained Tiling has been refactored into 🤖ffpa-attn-mma.
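
The Multi Stages, Copy Async and Prefetch KV g2s features in these kernels overlap the global-to-shared loads of the next tile with Tensor Core MMAs on the current one. Below is a minimal double-buffered (2-stage) prefetch sketch using the CUDA pipeline intrinsics; the tile size, thread count and the consume step are placeholders, and the actual kernels drive cp.async with their own tiling and swizzled SMEM layouts.

#include <cuda_pipeline.h>  // __pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior

// Double-buffered g2s prefetch: while MMAs consume smem[stage], the next tile
// is already being copied asynchronously into smem[stage ^ 1].
// TILE_BYTES and NTHREADS are placeholders; each thread issues 16-byte chunks.
template <int TILE_BYTES, int NTHREADS>
__device__ void two_stage_prefetch_sketch(const char* gmem, char* smem,
                                          int num_tiles, int tid) {
  auto issue_copy = [&](int tile, int stage) {
    const char* src = gmem + (size_t)tile * TILE_BYTES;
    char* dst = smem + (size_t)stage * TILE_BYTES;
    for (int off = tid * 16; off < TILE_BYTES; off += NTHREADS * 16)
      __pipeline_memcpy_async(dst + off, src + off, 16);
    __pipeline_commit();
  };

  int stage = 0;
  issue_copy(0, stage);  // prologue: start copying tile 0
  for (int t = 0; t < num_tiles; ++t) {
    if (t + 1 < num_tiles) issue_copy(t + 1, stage ^ 1);  // prefetch next tile
    __pipeline_wait_prior(t + 1 < num_tiles ? 1 : 0);     // tile t has landed
    __syncthreads();
    // ... consume smem + stage * TILE_BYTES with WMMA/MMA for tile t ...
    __syncthreads();  // ensure all warps are done before the buffer is reused
    stage ^= 1;
  }
}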

©️Citations🎉🎉

@misc{CUDA-Learn-Notes@2024,
  title={CUDA-Learn-Notes: A Modern CUDA Learn Notes with PyTorch for Beginners},
  url={https://github.com/DefTruth/CUDA-Learn-Notes},
  note={Open-source software available at https://github.com/DefTruth/CUDA-Learn-Notes},
  author={DefTruth et al.},
  year={2024}
}

📖 200+ CUDA Kernels 🔥🔥 (Easy -> Hard++) (©️back👆🏻)

The kernels listed here will guide you through a step-by-step progression, from easy to very challenging topics. The workflow for each topic is: custom CUDA kernel implementation -> PyTorch Python bindings -> run tests. 👉TIPS: * = Tensor Cores (WMMA, MMA, CuTe), otherwise CUDA Cores; / = not supported; ✔️ = supported; blank = TODO. Contents are listed as follows:

The 📚 Easy and 📚 Medium sections cover operations such as element-wise, mat_trans, warp/block reduce, nms, relu, gelu, swish, layer-norm, rms-norm, online-softmax, dot-prod and embedding, plus basic usage of FP32, FP16, BF16 and FP8. The 📚 Hard, 📚 Hard+ and 📚 Hard++ sections delve into advanced topics, primarily sgemv, sgemm, hgemv, hgemm and flash-attention. These sections also provide numerous kernels implemented with Tensor Cores using pure MMA PTX.
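
As a taste of the kernel -> binding -> test workflow, the sketch below shows the idea behind an elementwise_f32x4-style kernel: each thread handles four floats via a single float4 load/store. It is a simplified illustration (it assumes N is a multiple of 4); the repo's kernels also handle the tail and are exposed to Python through PyTorch C++ extensions.

#include <cuda_runtime.h>

// Vectorized elementwise add: one float4 (4 floats) per thread.
// Simplified sketch that assumes N % 4 == 0.
__global__ void elementwise_add_f32x4_sketch(const float* a, const float* b,
                                             float* c, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    float4 ra = *reinterpret_cast<const float4*>(a + idx);
    float4 rb = *reinterpret_cast<const float4*>(b + idx);
    float4 rc;
    rc.x = ra.x + rb.x; rc.y = ra.y + rb.y;
    rc.z = ra.z + rb.z; rc.w = ra.w + rb.w;
    *reinterpret_cast<float4*>(c + idx) = rc;
  }
}
// Example launch: elementwise_add_f32x4_sketch<<<(N / 4 + 255) / 256, 256>>>(a, b, c, N);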

📚 Easy ⭐️ & Medium ⭐️⭐️ (©️back👆🏻)

📖 CUDA Kernel 📖 Elem DType 📖 Acc DType 📖 Docs 📖 Level
✔️ elementwise_f32 f32 / link ⭐️
✔️ elementwise_f32x4 f32 / link ⭐️
✔️ elementwise_f16 f16 / link ⭐️
✔️ elementwise_f16x2 f16 / link ⭐️
✔️ elementwise_f16x8 f16 / link ⭐️
✔️ elementwise_f16x8_pack f16 / link ⭐️⭐️
✔️ histogram_i32 i32 / link ⭐️
✔️ histogram_i32x4 i32 / link ⭐️
✔️ sigmoid_f32 f32 / link ⭐️
✔️ sigmoid_f32x4 f32 / link ⭐️
✔️ sigmoid_f16 f16 / link ⭐️
✔️ sigmoid_f16x2 f16 / link ⭐️
✔️ sigmoid_f16x8 f16 / link ⭐️
✔️ sigmoid_f16x8_pack f16 / link ⭐️⭐️
✔️ relu_f32 f32 / link ⭐️
✔️ relu_f32x4 f32 / link ⭐️
✔️ relu_f16 f16 / link ⭐️
✔️ relu_f16x2 f16 / link ⭐️
✔️ relu_f16x8 f16 / link ⭐️
✔️ relu_f16x8_pack f16 / link ⭐️⭐️
✔️ elu_f32 f32 / link ⭐️
✔️ elu_f32x4 f32 / link ⭐️
✔️ elu_f16 f16 / link ⭐️
✔️ elu_f16x2 f16 / link ⭐️
✔️ elu_f16x8 f16 / link ⭐️
✔️ elu_f16x8_pack f16 / link ⭐️⭐️
✔️ gelu_f32 f32 / link ⭐️
✔️ gelu_f32x4 f32 / link ⭐️
✔️ gelu_f16 f16 / link ⭐️
✔️ gelu_f16x2 f16 / link ⭐️
✔️ gelu_f16x8 f16 / link ⭐️
✔️ gelu_f16x8_pack f16 / link ⭐️⭐️
✔️ swish_f32 f32 / link ⭐️
✔️ swish_f32x4 f32 / link ⭐️
✔️ swish_f16 f16 / link ⭐️
✔️ swish_f16x2 f16 / link ⭐️
✔️ swish_f16x8 f16 / link ⭐️
✔️ swish_f16x8_pack f16 / link ⭐️⭐️
✔️ hardswish_f32 f32 / link ⭐️
✔️ hardswish_f32x4 f32 / link ⭐️
✔️ hardswish_f16 f16 / link ⭐️
✔️ hardswish_f16x2 f16 / link ⭐️
✔️ hardswish_f16x8 f16 / link ⭐️
✔️ hardswish_f16x8_pack f16 / link ⭐️⭐️
✔️ hardshrink_f32 f32 / link ⭐️
✔️ hardshrink_f32x4 f32 / link ⭐️
✔️ hardshrink_f16 f16 / link ⭐️
✔️ hardshrink_f16x2 f16 / link ⭐️
✔️ hardshrink_f16x8 f16 / link ⭐️
✔️ hardshrink_f16x8_pack f16 / link ⭐️⭐️
✔️ embedding_f32 f32 / link ⭐️
✔️ embedding_f32x4 f32 / link ⭐️
✔️ embedding_f32x4_pack f32 / link ⭐️
✔️ embedding_f16 f16 / link ⭐️
✔️ embedding_f16x2 f16 / link ⭐️
✔️ embedding_f16x8 f16 / link ⭐️
✔️ embedding_f16x8_pack f16 / link ⭐️⭐️
✔️ mat_trans_f32_col2row{2d} f32 / link ⭐️
✔️ mat_trans_f32_row2col{2d} f32 / link ⭐️
✔️ mat_trans_f32_diagonal2d f32 / link ⭐️⭐️
✔️ mat_trans_f32x4_col2row{2d} f32 / link ⭐️⭐️
✔️ mat_trans_f32x4_row2col{2d} f32 / link ⭐️⭐️
✔️ warp_reduce_{all} all all link ⭐️⭐️
✔️ block_all_reduce_f32_f32 f32 f32 link ⭐️⭐️
✔️ block_all_reduce_f32x4_f32 f32 f32 link ⭐️⭐️
✔️ block_all_reduce_f16_f16 f16 f16 link ⭐️⭐️
✔️ block_all_reduce_f16_f32 f16 f32 link ⭐️⭐️
✔️ block_all_reduce_f16x2_f16 f16 f16 link ⭐️⭐️
✔️ block_all_reduce_f16x2_f32 f16 f32 link ⭐️⭐️
✔️ block_all_reduce_f16x8_pack_f16 f16 f16 link ⭐️⭐️
✔️ block_all_reduce_f16x8_pack_f32 f16 f32 link ⭐️⭐️
✔️ block_all_reduce_bf16_bf16 bf16 bf16 link ⭐️⭐️
✔️ block_all_reduce_bf16_f32 bf16 f32 link ⭐️⭐️
✔️ block_all_reduce_bf16x2_bf16 bf16 bf16 link ⭐️⭐️
✔️ block_all_reduce_bf16x2_f32 bf16 f32 link ⭐️⭐️
✔️ block_all_reduce_bf16x8_pack_bf16 bf16 bf16 link ⭐️⭐️
✔️ block_all_reduce_bf16x8_pack_f32 bf16 f32 link ⭐️⭐️
✔️ block_all_reduce_fp8_e4m3_f16 fp8_e4m3 f16 link ⭐️⭐️⭐️
✔️ block_all_reduce_fp8_e5m2_f16 fp8_e5m2 f16 link ⭐️⭐️⭐️
✔️ block_all_reduce_fp8_e4m3x16_pack_f16 fp8_e4m3 f16 link ⭐️⭐️⭐️
✔️ block_all_reduce_fp8_e5m2x16_pack_f16 fp8_e5m2 f16 link ⭐️⭐️⭐️
✔️ block_all_reduce_i8_i32 i8 i32 link ⭐️⭐️
✔️ block_all_reduce_i8x16_pack_i32 i8 i32 link ⭐️⭐️
✔️ dot_product_f32 f32 f32 link ⭐️⭐️
✔️ dot_product_f32x4 f32 f32 link ⭐️⭐️
✔️ dot_product_f16_f32 f16 f32 link ⭐️⭐️
✔️ dot_product_f16x2_f32 f16 f32 link ⭐️⭐️
✔️ dot_product_f16x8_pack_f32 f16 f32 link ⭐️⭐️
✔️ softmax_f32(fence) f32 f32 link ⭐️⭐️
✔️ softmax_f32x4(fence) f32 f32 link ⭐️⭐️
✔️ softmax_f32 f32 f32 link ⭐️⭐️
✔️ softmax_f32x4 f32 f32 link ⭐️⭐️
✔️ safe_softmax_f32 f32 f32 link ⭐️⭐️
✔️ safe_softmax_f32x4 f32 f32 link ⭐️⭐️
✔️ safe_softmax_f16_f32 f16 f32 link ⭐️⭐️
✔️ safe_softmax_f16x2_f32 f16 f32 link ⭐️⭐️
✔️ safe_softmax_f16x8_pack_f32 f16 f32 link ⭐️⭐️
✔️ online_safe_softmax_f32 f32 f32 link ⭐️⭐️
✔️ online_safe_softmax_f32x4_pack f32 f32 link ⭐️⭐️
✔️ rope_f32 f32 f32 link ⭐️⭐️
✔️ rope_f32x4_pack f32 f32 link ⭐️⭐️
✔️ layer_norm_f32 f32 f32 link ⭐️⭐️
✔️ layer_norm_f32x4 f32 f32 link ⭐️⭐️
✔️ layer_norm_f16_f16 f16 f16 link ⭐️⭐️
✔️ layer_norm_f16x2_f16 f16 f16 link ⭐️⭐️
✔️ layer_norm_f16x8_f16 f16 f16 link ⭐️⭐️
✔️ layer_norm_f16x8_pack_f16 f16 f16 link ⭐️⭐️
✔️ layer_norm_f16x8_pack_f32 f16 f32 link ⭐️⭐️
✔️ layer_norm_f16_f32 f16 f32 link ⭐️⭐️
✔️ rms_norm_f32 f32 f32 link ⭐️⭐️
✔️ rms_norm_f32x4 f32 f32 link ⭐️⭐️
✔️ rms_norm_f16_f16 f16 f16 link ⭐️⭐️
✔️ rms_norm_f16x2_f16 f16 f16 link ⭐️⭐️
✔️ rms_norm_f16x8_f16 f16 f16 link ⭐️⭐️
✔️ rms_norm_f16x8_f32 f16 f32 link ⭐️⭐️
✔️ rms_norm_f16x8_pack_f16 f16 f16 link ⭐️⭐️
✔️ rms_norm_f16x8_pack_f32 f16 f32 link ⭐️⭐️
✔️ rms_norm_f16_f32 f16 f32 link ⭐️⭐️
✔️ nms_f32 f32 / link ⭐️⭐️
✔️ notes v1(deprecated) f32 f32 / ⭐️⭐️
✔️ How to use nsys/ncu(timeline/ptx/sass) / / link ⭐️⭐️
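
Most of the reduce, softmax, norm and dot-product kernels above rest on the same warp-shuffle primitive; a minimal warp-level sum reduction in the spirit of warp_reduce_{all} is sketched below (a generic pattern, not the repo's exact template):

// Butterfly (xor) reduction: after 5 shuffle steps every lane of the warp
// holds the sum of all 32 lanes. block_all_reduce_* kernels then combine
// the per-warp results through shared memory.
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = 16; mask >= 1; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask);
  return val;
}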

📚 Hard ⭐⭐⭐️ (©️back👆🏻)

📖 CUDA Kernel 📖 Elem DType 📖 Acc DType 📖 Docs 📖 Level
✔️ sgemv_k32_f32 f32 f32 link ⭐️⭐️⭐️
✔️ sgemv_k128_f32x4 f32 f32 link ⭐️⭐️⭐️
✔️ sgemv_k16_f32 f32 f32 link ⭐️⭐️⭐️
✔️ hgemv_k32_f16 f16 f16 link ⭐️⭐️⭐️
✔️ hgemv_k128_f16x4 f16 f16 link ⭐️⭐️⭐️
✔️ hgemv_k16_f16 f16 f16 link ⭐️⭐️⭐️
✔️ sgemm_naive_f32 f32 f32 link ⭐️⭐️
✔️ sgemm_sliced_k_f32 f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_t_8x8_sliced_k_f32x4 f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_t_8x8_sliced_k...bcf f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_t_8x8_sliced_k...dbuf f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_t_8x8_sliced_k16...dbuf f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_t_8x8_sliced_k16...async f32 f32 link ⭐️⭐️⭐️
✔️ sgemm_wmma_m16n16k8...stages* tf32 f32 link ⭐️⭐️⭐️
✔️ sgemm_wmma_m16n16k8...swizzle* tf32 f32 link ⭐️⭐️⭐️
✔️ hgemm_naive_f16 f16 f16 link ⭐️⭐️
✔️ hgemm_sliced_k_f16 f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8x8_sliced_k_f16x4 f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8x8_sliced_k_f16x4_pack f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8x8_sliced_k_f16x8_pack f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8x8_sliced_k...dbuf f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8/16x8...k16/32...dbuf f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_t_8/16x8...k16/32...async f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...naive* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...mma4x2* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...mma4x4* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...dbuf* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m32n8k16....dbuf* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...stages* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_wmma_m16n16k16...swizzle* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...naive* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...mma2x4* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...stages* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...swizzle* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...swizzle{smem}* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_m16n8k16...swizzle{tn}{smem}* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_stages_swizzle{smem}...cute* f16 f16 link ⭐️⭐️⭐️
✔️ hgemm_mma_cublas* f16 f16 link ⭐️⭐️
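
The hgemm_mma_* kernels drop down from the WMMA API to the raw m16n8k16 HMMA PTX instruction. A single issue with FP16 accumulators looks roughly like the sketch below; loading the register fragments (normally via ldmatrix) is omitted, and the helper name is illustrative only.

#include <cstdint>

// One Tensor Core mma.m16n8k16 with f16 accumulators (warp-wide).
// Per thread: RA = 4 regs of A (16x16 half), RB = 2 regs of B (16x8 half),
// RC = 2 regs of C/D (16x8 half); each uint32_t packs two halfs.
__device__ __forceinline__ void hmma_m16n8k16_f16_sketch(
    uint32_t* RC, const uint32_t* RA, const uint32_t* RB) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
      : "=r"(RC[0]), "=r"(RC[1])
      : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]),
        "r"(RB[0]), "r"(RB[1]), "r"(RC[0]), "r"(RC[1]));
}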

📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ (©️back👆🏻)

  • 📚 FlashAttention-2 MMA (MMA Acc F32/F16, swizzle, QKV smem share, fine-grained tiling, etc.🎉)
📖 CUDA Kernel 📖 Elem DType 📖 Acc DType 📖 Docs 📖 Level
✔️ How to implement MMA smem swizzle* f16 f16 link ⭐️⭐️⭐️
✔️ flash_attn_mma_stages_split_kv* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages_split_q* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...shared_kv* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...shared_qkv* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...tiling_qk* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...tiling_qkv* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...shared_kv{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...shared_qkv{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...tiling_qk{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma_stages...tiling_qkv{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_kv{f32}{rr}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_qkv{f32}{rr}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_kv_swizzle{q}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_kv_swizzle{qk}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_kv_swizzle{qkv}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_qkv_swizzle{q}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_qkv_swizzle{qk}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...shared_qkv_swizzle{qkv}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qk_swizzle{q}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qk_swizzle{qk}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qk_swizzle{qkv}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qkv_swizzle{q}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qkv_swizzle{qk}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn_mma...tiling_qkv_swizzle{qkv}* f16 f16 link ⭐️⭐️⭐️⭐️
✔️ flash_attn...tiling_qkv_swizzle{q}{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn...tiling_qkv_swizzle{qk}{f32}* f16 f32 link ⭐️⭐️⭐️⭐️
✔️ flash_attn...tiling_qkv_swizzle{qkv}{f32}* f16 f32 link ⭐️⭐️⭐️⭐️

💡NOTE: rr means reduced register usage (for d > 128); f32 means the MMA accumulates in FP32 dtype, otherwise FP16 (the softmax Acc dtype is always FP32 for high precision); swizzle: currently only SMEM swizzle for MMA is supported.

  • 📚 FFPA Attention MMA (1.8x~3x🎉faster vs SDPA EA, D > 256, FA2 not supported)
📖 CUDA Kernel 📖 Elem DType 📖 Acc DType 📖 Docs 📖 Level
✔️ ffpa_mma_stages_split_q_L1_F16F16F16 f16 f16 link ⭐️⭐️⭐️⭐️
✔️ ffpa_mma_stages_split_q_L1_F16F16F32 f16 f32 link ⭐️⭐️⭐️⭐️
✔️ ffpa_mma_stages_split_q_L1_mixed_acc f16 QK f32, PV f16 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L2_F16F16F16 f16 f16 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L2_F16F16F32 f16 f32 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L2_mixed_acc f16 QK f32, PV f16 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L3_F16F16F16 f16 f16 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L3_F16F16F32 f16 f32 link ⭐️⭐️⭐️⭐️
⚠️ ffpa_mma_stages_split_q_L3_mixed_acc f16 QK f32, PV f16 link ⭐️⭐️⭐️⭐️

💡NOTE: 🤖ffpa-attn-mma: 📚FFPA - Yet another Faster Flash Prefill Attention with O(1)🎉SRAM complexity for headdim > 256, 1.8x~3x🎉faster than SDPA EA: 📈L20 ~1.9x↑🎉, 📈 A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉.

📖 100+ LLM/VLM/CV/CUDA/CuTe Tech Blogs

📚 LLM | Multimodal | Diffusion | Inference Optimization (by the author) (©️back👆🏻)

📖 Category-Title 📖 Author 📖 Recommendation
[vLLM实践]📚vLLM + DeepSeek-R1 671B 多机部署及修Bug笔记 @DefTruth ⭐️⭐️⭐⭐️
[Attention优化]📚FFPA(Split-D): FA2无限HeadDim扩展,2x↑🎉 vs SDPA EA @DefTruth ⭐️⭐️⭐⭐️
[CUDA基础][开篇]📖CUDA-Learn-Notes: v3.0 大升级-面试刷题不迷路 @DefTruth ⭐️⭐️⭐⭐️
[分布式训推][张量/序列并行]📖图解DeepSpeed-Ulysses&Megatron-LM TP/SP @DefTruth ⭐️⭐️
[VLM推理优化][InternVL系列]📖InternLM2/.../InternVL1.5系列笔记: 核心点解析 @DefTruth ⭐️⭐️
[LLM推理优化][TensorRT-LLM][5w字]📖TensorRT-LLM部署调优-指北 @DefTruth ⭐️⭐️⭐️
[LLM推理优化][KV Cache优化]📖GQA/YOCO/CLA/MLKV: 层内和层间KV Cache共享 @DefTruth ⭐️⭐️
[LLM推理优化][Prefill优化]📖图解vLLM Prefix Prefill Triton Kernel @DefTruth ⭐️⭐️⭐️
[LLM推理优化][Prefill优化][万字]📖图解vLLM Automatic Prefix Caching: TTFT优化 @DefTruth ⭐️⭐️⭐️
[LLM推理优化][Attention优化]📖图解:从Online-Softmax到FlashAttention V1/V2/V3 @DefTruth ⭐️⭐️⭐️
[LLM推理优化][Decoding优化]📖原理&图解FlashDecoding/FlashDecoding++ @DefTruth ⭐️⭐️
[VLM推理优化][LLaVA系列]📖CLIP/LLaVA/LLaVA1.5/VILA笔记: 核心点解析 @DefTruth ⭐️⭐️
[LLM推理优化][Attention优化][万字]📖TensorRT MHA/Myelin vs FlashAttention-2 @DefTruth ⭐️⭐️⭐️
[LLM推理优化][PTX汇编]📖CUDA 12 PTX汇编: PRMT指令详解-通用模式 @DefTruth ⭐️
[LLM推理优化][PTX汇编]📖CUDA 12 PTX汇编: LOP3指令详解 @DefTruth ⭐️
[LLM推理优化][CUDA][3w字]📖高频面试题汇总-大模型手撕CUDA @DefTruth ⭐️⭐️⭐️
[LLM推理优化][Weight Only]📖WINT8/4-(00): 通俗易懂讲解-快速反量化算法 @DefTruth ⭐️⭐️
[LLM推理优化][Weight Only]📖WINT8/4-(01): PRMT指令详解及FT源码解析 @DefTruth ⭐️⭐️
[LLM推理优化][Weight Only]📖WINT8/4-(02): 快速反量化之INT8转BF16 @DefTruth ⭐️⭐️
[LLM推理优化][Weight Only]📖WINT8/4-(03): LOP3指令详解及INT4转FP16/BF16 @DefTruth ⭐️⭐️
[LLM推理优化][LLM Infra整理]📖100+篇: 大模型推理各方向新发展整理 @DefTruth ⭐️⭐️
[LLM推理优化][LLM Infra整理]📖30+篇: LLM推理论文集-500页PDF @DefTruth ⭐️⭐️
[LLM推理优化][LLM Infra整理]📖FlashDecoding++: 比FlashDecoding还要快! @DefTruth ⭐️
[LLM推理优化][LLM Infra整理]📖TensorRT-LLM开源,TensorRT 9.1也来了 @DefTruth ⭐️
[LLM推理优化][LLM Infra整理]📖20+篇: LLM推理论文集-300页PDF @DefTruth ⭐️⭐️
[LLM推理优化][LLM Infra整理]📖PagedAttention论文新鲜出炉 @DefTruth ⭐️

📚 CV Inference Deployment | C++ | Algorithms | Tech Notes (by the author) (©️back👆🏻)

📖 Category-Title 📖 Author 📖 Recommendation
[推理部署][CV/NLP]📖FastDeploy三行代码搞定150+ CV、NLP模型部署 @DefTruth ⭐️
[推理部署][CV]📖如何在lite.ai.toolkit(3.6k+ stars)中增加您的模型? @DefTruth ⭐️⭐️
[推理部署][CV]📖美团 YOLOv6 ORT/MNN/TNN/NCNN C++推理部署 @DefTruth ⭐️⭐️
[推理部署][ONNX]📖ONNX推理加速技术文档-杂记 @DefTruth ⭐️
[推理部署][TensorFlow]📖Mac源码编译TensorFlow C++指北 @DefTruth ⭐️
[推理部署][CV]📖1Mb!头部姿态估计: FSANet,一个小而美的模型(C++) @DefTruth ⭐️
[推理部署][CV]📖opencv+ffmpeg编译打包全解指南 @DefTruth ⭐️⭐️
[推理部署][CV]📖RobustVideoMatting视频抠图静态ONNX模型转换 @DefTruth ⭐️
[推理部署][CV]📖190Kb!SSRNet年龄检测详细解读(含C++工程) @DefTruth ⭐️
[推理部署][CV]📖MGMatting(CVPR2021)人像抠图C++应用记录 @DefTruth ⭐️
[推理部署][CV]📖超准确人脸检测(带关键点)YOLO5Face C++工程详细记录 @DefTruth ⭐️⭐️
[推理部署][ORT]📖解决: ONNXRuntime(Python) GPU 部署配置记录 @DefTruth ⭐️
[推理部署][CV]📖记录SCRFD(CVPR2021)人脸检测C++工程化(含docker镜像) @DefTruth ⭐️⭐️
[推理部署][NCNN]📖野路子:记录一个解决onnx转ncnn时op不支持的trick @DefTruth ⭐️
[推理部署][CV]📖升级版轻量级NanoDet-Plus MNN/TNN/NCNN/ORT C++工程记录 @DefTruth ⭐️⭐️
[推理部署][CV]📖超轻量级NanoDet MNN/TNN/NCNN/ORT C++工程记录 @DefTruth ⭐️
[推理部署][CV]📖详细记录MGMatting之MNN、TNN和ORT C++移植 @DefTruth ⭐️⭐️
[推理部署][CV]📖YOLOX NCNN/MNN/TNN/ONNXRuntime C++工程简记 @DefTruth ⭐️
[推理部署][TNN]📖手动修改YoloX的tnnproto记录-TNN @DefTruth ⭐️
[推理部署][ORT]📖全网最详细 ONNXRuntime C++/Java/Python 资料! @DefTruth ⭐️
[推理部署][CV]📖RobustVideoMatting: C++工程化记录-实现篇 @DefTruth ⭐️⭐️
[推理部署][CV]📖RobustVideoMatting: C++工程化记录-应用篇 @DefTruth ⭐️⭐️
[推理部署][ORT]📖ONNXRuntime C++ CMake 工程分析及编译 @DefTruth ⭐️⭐️
[推理部署][ORT]📖如何使用ORT C++ API处理NCHW和NHWC输入? @DefTruth ⭐️
[推理部署][TNN]📖tnn-convert搭建简记-YOLOP转TNN @DefTruth ⭐️
[推理部署][CV]📖YOLOP ONNXRuntime C++工程化记录 @DefTruth ⭐️⭐️
[推理部署][NCNN]📖超有用NCNN参考资料整理 @DefTruth ⭐️
[推理部署][MNN]📖超有用MNN参考资料整理 @DefTruth ⭐️
[推理部署][TNN]📖超有用TNN参考资料整理 @DefTruth ⭐️
[推理部署][ONNX]📖超有用ONNX参考资料整理 @DefTruth ⭐️
[推理部署][ONNX]📖超有用ONNX模型结构参考资料整理 @DefTruth ⭐️
[推理部署][OpenCV-DNN]📖超有用OpenCV-DNN参考资料整理 @DefTruth ⭐️
[推理部署][Tensorflow]📖超有用Tensorflow C++工程化知识点 @DefTruth ⭐️
[推理部署][模型转换]📖深度学习模型转换资料整理 @DefTruth ⭐️
[技术随笔][C++][CMake]📖超有用CMake参考资料整理 @DefTruth ⭐️⭐️
[技术随笔][C++][3W字]📖静态链接和静态库实践指北-原理篇 @DefTruth ⭐️⭐️⭐️
[技术随笔][C++]📖Mac下C++内存检查指北(Valgrind VS Asan) @DefTruth ⭐️
[技术随笔][CV]📖torchlm: 人脸关键点检测库 @DefTruth ⭐️⭐️
[技术随笔][ML]📖《统计学习方法-李航: 笔记-从原理到实现-基于R》 @DefTruth ⭐️⭐️
[技术随笔][Git]📖如何优雅地git clone和git submodule? @DefTruth ⭐️
[技术随笔][3D]📖人脸重建3D参考资料整理 @DefTruth ⭐️
[技术随笔][3D]📖BlendShapes参考资料整理 @DefTruth ⭐️
[技术随笔][3D]📖从源码安装Pytorch3D详细记录及学习资料 @DefTruth ⭐️
[技术随笔][ML]📖200页:《统计学习方法:李航》笔记 -从原理到实现 @DefTruth ⭐️⭐️

📚 CUTLASS | CuTe | NCCL | CUDA | Recommended Articles (by other authors) (©️back👆🏻)

💡Note: this subsection collects articles that I particularly like. Everyone is welcome to open a PR to recommend more great articles!

📖 Category-Title 📖 Author 📖 Recommendation
[cute系列详解][入门]📖cutlass cute 101 @朱小霖 ⭐️⭐️⭐️
[cute系列详解][入门]📖CUTLASS 2.x & CUTLASS 3.x Intro 学习笔记 @BBuf ⭐️⭐️⭐️
[cute系列详解][Layout]📖cute 之 Layout @reed ⭐️⭐️⭐️
[cute系列详解][Layout]📖cute Layout 的代数和几何解释 @reed ⭐️⭐️⭐️
[cute系列详解][Tensor]📖cute 之 Tensor @reed ⭐️⭐️⭐️
[cute系列详解][MMA]📖cute 之 MMA抽象 @reed ⭐️⭐️⭐️
[cute系列详解][Copy]📖cute 之 Copy抽象 @reed ⭐️⭐️⭐️
[cute系列详解][Swizzle]📖cute 之 Swizzle @reed ⭐️⭐️⭐️
[cute系列详解][Swizzle]📖cute Swizzle细谈 @进击的Killua ⭐️⭐️⭐️
[cute系列详解][Swizzle]📖cutlass swizzle机制解析(一) @Titus ⭐️⭐️⭐️
[cute系列详解][Swizzle]📖cutlass swizzle机制解析(二) @Titus ⭐️⭐️⭐️
[cute系列详解][Swizzle]📖CUDA避免smem bank conflict的swizzle机制解析 @frankshi ⭐️⭐️⭐️
[cute系列详解][GEMM]📖cute 之 简单GEMM实现 @reed ⭐️⭐️⭐️
[cute系列详解][GEMM]📖cute 之 GEMM流水线 @reed ⭐️⭐️⭐️
[cute系列详解][GEMM]📖cute 之 高效GEMM实现 @reed ⭐️⭐️⭐️
[cute系列详解][GEMM]📖GEMM流水线: single/multi-stage、pipeline @Titus ⭐️⭐️⭐️
[cute系列详解][GEMM]📖GEMM细节分析(一): ldmatrix的选择 @Anonymous ⭐️⭐️⭐️
[cute系列详解][GEMM]📖GEMM细节分析(二): TiledCopy与cp.async @Anonymous ⭐️⭐️⭐️
[cute系列详解][GEMM]📖GEMM细节分析(三): Swizzle<B,M,S>参数取值 @Anonymous ⭐️⭐️⭐️
[cute系列详解][实践]📖Hopper Mixed GEMM的CUTLASS实现笔记 @BBuf ⭐️⭐️⭐️
[cute系列详解][实践]📖CUTLASS CuTe实战(一): 基础 @进击的Killua ⭐️⭐️⭐️
[cute系列详解][实践]📖CUTLASS CuTe实战(二): 应用 @进击的Killua ⭐️⭐️⭐️
[cute系列详解][实践]📖FlashAttention fp8实现(ada架构) @shengying.wei ⭐️⭐️⭐️
[cute系列详解][实践]📖FlashAttention 笔记: tiny-flash-attention解读 @shengying.wei ⭐️⭐️⭐️
[cute系列详解][实践]📖使用cutlass cute复现flash attention @66RING ⭐️⭐️⭐️
[cutlass教程][入门]📖cutlass 基本认知 @JoeNomad ⭐️⭐️⭐️
[cutlass教程][入门]📖cutlass 软件架构 @JoeNomad ⭐️⭐️⭐️
[cutlass教程][入门]📖CUTLASS 基础介绍 @进击的Killua ⭐️⭐️⭐️
[cutlass教程][入门]📖乱谈CUTLASS GTC2020 SLIDES @zzk again ⭐️⭐️⭐️
[cutlass教程][深入]📖cutlass block swizzle 和 tile iterator @JoeNomad ⭐️⭐️⭐️
[cutlass教程][深入]📖cutlass bank conflict free的smem layout @JoeNomad ⭐️⭐️⭐️
[cutlass教程][深入]📖cutlass 多级流水线 @JoeNomad ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-前言 @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-寄存器 @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-Load和Cache @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-浮点运算 @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-整数运算 @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-比特和逻辑操作 @reed ⭐️⭐️⭐️
[GPU指令集架构][精解]📖NVidia GPU指令集架构-Warp级和Uniform操作 @reed ⭐️⭐️⭐️
[CUDA优化][入门]📖CUDA(一):CUDA 编程基础 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][入门]📖CUDA(二):GPU的内存体系及其优化指南 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖CUDA(三):通用矩阵乘法:从入门到熟练 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(1):LayerNorm 算子的 CUDA 实现与优化 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(2):SoftMax算子的 CUDA 实现 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(3):Cross Entropy 的 CUDA 实现 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(4):AdamW 优化器的 CUDA 实现 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(5):激活函数与残差连接的 CUDA 实现 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(6):embedding 层与 LM head 层的 CUDA 实现 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(7):self-attention 的 CUDA 实现及优化 (上) @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖ops(8):self-attention 的 CUDA 实现及优化 (下) @紫气东来 ⭐️⭐️⭐️
[CUDA优化][实践]📖CUDA(四):使用 CUDA 实现 Transformer 结构 @紫气东来 ⭐️⭐️⭐️
[CUDA优化][Copy]📖Async Copy及Memory Barrier指令的功能与实现 @Frank Wang ⭐️⭐️⭐️
[CUDA优化][GEMV]📖深入浅出GPU优化系列:gemv优化 @有了琦琦的棍子 ⭐️⭐️⭐️
[Tensor Cores]📖Nvidia Tensor Core初探 @木子知 ⭐️⭐️⭐️
[Tensor Cores]📖Nvidia Tensor Core-WMMA API编程入门 @木子知 ⭐️⭐️⭐️
[Tensor Cores]📖Nvidia Tensor Core-MMA PTX编程入门 @木子知 ⭐️⭐️⭐️
[Tensor Cores]📖CUDA Ampere Tensor Core HGEMM 矩阵乘法优化 @nicholaswilde ⭐️⭐️⭐️
[GPU通信架构][精解]📖NVIDIA GPGPU(四)- 通信架构 @Bruce ⭐️⭐️⭐️

©️License (©️back👆🏻)

GNU General Public License v3.0

🎉Contribute (©️back👆🏻)

How to contribute? Star this repo or check 🌤🌤CONTRIBUTE🎉🎉.

📖 References (©️back👆🏻)