
[gfx1201] Enable RMSNorm support for gfx1201 #4

Open
big-yellow-duck wants to merge 4 commits into main from rdna4-rmsnorm-support

Conversation

@big-yellow-duck

Motivation

The RMSNorm kernels in csrc/kernels/rmsnorm_quant_kernels.cu use CDNA-specific inline assembly instructions that are not supported on the RDNA4 (gfx1201) architecture, so the RMSNorm operation fails on gfx1201 GPUs. This PR enables RMSNorm on gfx1201 by replacing the unsupported assembly instructions with portable HIP/C++ alternatives.

Technical Details

Changes Overview

Modified csrc/kernels/rmsnorm_quant_kernels.cu to replace CDNA-specific inline assembly with portable implementations for gfx11/gfx12:

  1. Replaced v_pk_mul_f32 inline assembly (lines 146-151, 196-201)

    • Changed from: asm volatile("v_pk_mul_f32 %0, %1, %2" ...)
    • Changed to: Standard float multiplication with __gfx11__ || __gfx12__ guard
    • This instruction is not supported on the RDNA4 (gfx12xx) architecture
  2. Replaced bf16 unpacking inline assembly (lines 162-176)

    • Changed from: v_lshlrev_b32_e32 and v_and_b32_e32 instructions
    • Changed to: ck_tile::bit_cast with shift operations for unpacking bf16 values
    • Provides equivalent functionality using portable HIP/C++ code
  3. Replaced fp16 unpacking inline assembly (lines 180-194)

    • Changed from: v_cvt_f32_f16_e32 and v_cvt_f32_f16_sdwa instructions
    • Changed to: ck_tile::bit_cast with shift operations for unpacking fp16 values
    • SDWA (Sub-Dword Addressing) instruction variants are not supported on gfx11/gfx12
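The unpacking in items 2 and 3 can be sketched in portable C++ as follows. This is a hedged host-side sketch, not the PR's exact kernel code: `bit_cast` here stands in for `ck_tile::bit_cast`, and the fp16 converter handles only zeros and normal values for brevity (on-device, the compiler's half-precision conversions can be used instead).

```cpp
#include <cstdint>
#include <cstring>

// Stand-in for ck_tile::bit_cast (or C++20 std::bit_cast).
template <typename To, typename From>
To bit_cast(const From& src) {
    static_assert(sizeof(To) == sizeof(From), "size mismatch");
    To dst;
    std::memcpy(&dst, &src, sizeof(To));
    return dst;
}

// Two bf16 values packed in one 32-bit register.
// Low half: shift into the high 16 bits of a float's encoding.
// High half: mask off the low bits; they already sit in the right position.
inline void unpack_bf16x2(uint32_t packed, float& lo, float& hi) {
    lo = bit_cast<float>(packed << 16);
    hi = bit_cast<float>(packed & 0xFFFF0000u);
}

// Minimal fp16 (IEEE binary16) to float conversion, zeros and normals only,
// standing in for v_cvt_f32_f16_e32 and its SDWA high-half variant.
inline float half_bits_to_float(uint16_t h) {
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    if (exp == 0 && mant == 0) return bit_cast<float>(sign);  // +/-0
    // rebias exponent from 15 to 127; subnormals/inf/NaN omitted in this sketch
    return bit_cast<float>(sign | ((exp + 112u) << 23) | (mant << 13));
}

inline void unpack_fp16x2(uint32_t packed, float& lo, float& hi) {
    lo = half_bits_to_float(static_cast<uint16_t>(packed & 0xFFFFu));
    hi = half_bits_to_float(static_cast<uint16_t>(packed >> 16));
}
```

For example, a register packing bf16(1.0) = 0x3F80 in the low half and bf16(2.0) = 0x4000 in the high half unpacks to 1.0f and 2.0f.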

Compatibility

  • CDNA (gfx90a, gfx942): No functional change - continues to use optimized inline assembly
  • RDNA4 (gfx1201): Now uses portable HIP/C++ implementation
  • All changes are guarded by preprocessor conditions (#if defined(__gfx11__) || defined(__gfx12__)) to ensure optimal performance on both architectures
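The guard structure can be sketched as below. This is a host-compilable approximation, not the kernel's literal code: on a host compiler neither `__gfx11__` nor `__gfx12__` is defined, so the CDNA branch is written as equivalent arithmetic with the real instruction noted in a comment.

```cpp
struct float2_t { float x, y; };

// Packed fp32 multiply, dispatched at compile time by GPU target.
inline float2_t pk_mul_f32(float2_t a, float2_t b) {
#if defined(__gfx11__) || defined(__gfx12__)
    // RDNA path: plain C++ multiplies; the compiler picks suitable VALU ops.
    return {a.x * b.x, a.y * b.y};
#else
    // CDNA path in the real kernel uses inline assembly along the lines of
    //   asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(r) : "v"(a), "v"(b));
    // written here as the equivalent arithmetic so this sketch compiles anywhere.
    return {a.x * b.x, a.y * b.y};
#endif
}
```

Both branches compute the same two independent products; only the instruction selection differs per architecture.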

Test Plan

Run the RMSNorm test suites to validate the changes:

# Test fused RMSNorm with add and quantization (FP8)
python op_tests/test_rmsnorm2dFusedAddQuant.py --mode 7 -q fp8

# Test standard RMSNorm operations
python op_tests/test_rmsnorm2d.py

Tests cover:

  • RMSNorm with residual addition
  • FP8 quantization paths
  • Various hidden dimension sizes (up to 8192)
  • Both bf16 and fp16 input data types
  • Per-token and per-channel quantization modes
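As a reference for what these suites validate, RMSNorm with residual addition reduces to the following minimal sketch, assuming the usual definition y = (x + residual) * w / sqrt(mean((x + residual)^2) + eps). Function and parameter names are illustrative, not the test suite's API.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm with fused residual add.
std::vector<float> rmsnorm_fused_add(const std::vector<float>& x,
                                     const std::vector<float>& residual,
                                     const std::vector<float>& weight,
                                     float eps = 1e-6f) {
    std::vector<float> y(x.size());
    float sum_sq = 0.f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] + residual[i];  // fused residual add
        sum_sq += y[i] * y[i];
    }
    const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
    for (std::size_t i = 0; i < y.size(); ++i)
        y[i] = y[i] * inv_rms * weight[i];
    return y;
}
```

The quantized variants additionally scale the output into the fp8 range; the kernel under test must match this reference within the reported error tolerance.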

Test Result

Test result for python op_tests/test_rmsnorm2dFusedAddQuant.py --mode 7 -q fp8

| m | n | quant_type | add_residual | dtype | quant_dtype | smoothquant | torch us | hip us | hip err | hip bw (GB/s) |
|---|---|---|---|---|---|---|---|---|---|---|
| 8 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 33.2036 | 2.18041 | 0.0317383 | 25.4776 |
| 256 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 100.013 | 4.89922 | 0.0306206 | 350.774 |
| 2048 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 163.156 | 22.8312 | 0.0308919 | 601.581 |
| 2560 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 141.778 | 23.6688 | 0.0309967 | 725.346 |
| 32768 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 2402.92 | 390.831 | 0.0310231 | 562.208 |
| 8 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 24.3369 | 2.30625 | 0.0294189 | 48.1748 |
| 256 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 84.8793 | 14.8431 | 0.0314503 | 231.558 |
| 2048 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 257.192 | 39.1746 | 0.0309095 | 701.209 |
| 2560 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 308.477 | 47.5147 | 0.0309946 | 722.641 |
| 32768 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 5155.18 | 777.68 | 0.0310699 | 565.087 |
| 8 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 53.5412 | 2.74501 | 0.0287476 | 80.9491 |
| 256 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 105.035 | 18.4179 | 0.0307817 | 373.228 |
| 2048 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 400.869 | 102.266 | 0.0310012 | 537.218 |
| 2560 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 590.208 | 126.793 | 0.0310267 | 541.608 |
| 32768 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 9250.57 | 1574.4 | 0.0310284 | 558.254 |
| 8 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 36.7148 | 2.96228 | 0.0305786 | 150.024 |
| 256 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 116.373 | 31.7526 | 0.0309405 | 432.977 |
| 2048 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 1132.48 | 200.516 | 0.0310664 | 547.979 |
| 2560 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 1471.56 | 249.542 | 0.0309757 | 550.385 |
| 32768 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 18541.1 | 3156.56 | 0.0310378 | 556.88 |

Submission Checklist

@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 4 --add-label <label>


@big-yellow-duck remove unnecessary changes.

asm volatile( "s_mov_b32 m0 %0\n\t"
"buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
asm volatile("s_mov_b32 m0, %0; buffer_load_dword %1, %2, 0 offen lds;" :: "s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc) : "memory");
@big-yellow-duck is this change necessary? It looks the same. If it is just an alignment issue, please revert this to the original code. Let's not reformat the code, to avoid confusion: it makes the reviewer think you have modified the instruction.

