
[gfx1201] Enable quantization kernels for gfx1201#3

Open
big-yellow-duck wants to merge 2 commits into main from rdna4-quant-support

Conversation


@big-yellow-duck big-yellow-duck commented Mar 13, 2026

Motivation

FP8 quantization operations fail on AMD gfx1201 (RDNA4) architecture due to three compatibility issues:

  1. FP8 dtype is not registered for gfx1201 in the dtype mapping
  2. v_pk_mul_f32 assembly instruction is not supported on gfx11/gfx12
  3. DPP broadcast operations (0x142, 0x143) used in hip_reduce.h are not supported on gfx11/gfx12

This PR enables FP8 quantization support on gfx1201 by addressing these incompatibilities.

Technical Details

1. FP8 Dtype Registration (aiter/utility/dtypes.py)

Added gfx1201 to the default FP8 dtype mapping to enable torch.float8_e4m3fn support on RDNA4.
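The diff for dtypes.py isn't shown here, but the shape of the change can be sketched in plain Python. All entries and the helper name below are illustrative stand-ins; the real module maps architecture names to torch dtype objects (e.g. torch.float8_e4m3fn), with strings used here to keep the sketch self-contained:

```python
# Hypothetical sketch of the mapping in aiter/utility/dtypes.py.
# Strings stand in for torch dtypes; the gfx942 entry is only an
# assumed example of a pre-existing registration.
_DEFAULT_FP8_DTYPE = {
    "gfx942": "float8_e4m3fnuz",
    "gfx950": "float8_e4m3fn",
    "gfx1201": "float8_e4m3fn",  # added by this PR: RDNA4 uses OCP e4m3fn
}

def default_fp8_dtype(arch: str) -> str:
    """Return the default FP8 dtype name for a GPU architecture."""
    try:
        return _DEFAULT_FP8_DTYPE[arch]
    except KeyError:
        raise ValueError(f"FP8 dtype not registered for {arch}")

print(default_fp8_dtype("gfx1201"))  # float8_e4m3fn
```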

2. Scalar Multiplication Fallback (csrc/include/ck_tile/vec_convert.h)

The v_pk_mul_f32 assembly instruction is not supported on gfx11/gfx12, so an amd_scalar_mul_f32() helper was added as a portable fallback:

CK_TILE_DEVICE fp32x2_t amd_scalar_mul_f32(fp32x2_t a, fp32x2_t b)
{
    fp32x2_t c;
    c[0] = a[0] * b[0];
    c[1] = a[1] * b[1];
    return c;
}

The conversion functions fp32x2_t_to_fp8x2_t and fp32x2_t_to_int8x2_t now conditionally use the scalar path:

#if defined(__gfx11__) || defined(__gfx12__)
    tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
    tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif

3. DPP Broadcast Replacement (csrc/include/hip_reduce.h)

DPP broadcast operations are not supported on gfx11/gfx12. Replaced with rocprim::warp_shuffle() for cross-lane communication in:

  • wave_reduce() - for WarpSize > 16 and WarpSize > 32 reductions
  • multithread_reduce() - for 16-thread and 32-thread reduction paths

Example change:

#if defined(__gfx12__) || defined(__gfx11__)
    // Use shuffle for gfx11/gfx12 instead of DPP broadcast
    T v_remote = rocprim::warp_shuffle(local, 15, WarpSize);
    local      = reduce_op(v_remote, local);
#else
    // row_bcast:15
    local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(local), local);
#endif

4. Naive Load-to-LDS Fallback (csrc/kernels/quant_kernels.cu)

On gfx12, smooth_per_token_scaled_quant_kernel falls back to a naive copy from global memory to LDS, since RDNA4 does not support buffer_load_* with the LDS modifier.

for(int i = 0; i < async_load_num; i++)
{
#if defined(__gfx12__)
    int idx = threadIdx.x + i * block_size;
    if(idx < smooth_scale_map_hash_size)
    {
        // RDNA4 doesn't support buffer_load_* with the LDS modifier.
        // Use a standard global load to VGPR, then write to LDS.
        smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx];
    }
#else
    const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane(reinterpret_cast<uintptr_t>(
        smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size));
    uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
    asm volatile("s_mov_b32 m0, %0\n\t"
                 "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
                 :: "s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc)
                 : "memory", "m0");
#endif
}

Test Plan

Run the quantization test suite with various tensor sizes:

python op_tests/test_quant.py -m 1 2 16 32 64 128 192 256 512 1024 16384

Test Result

All quantization tests pass successfully on gfx1201:

| m | n | q_type | q_dtype | h_dtype | triton dq | triton dq err | hip dq | hip dq err |
|---|---|--------|---------|---------|-----------|---------------|--------|------------|
| 1 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 3.64439 | 0 | 1.93756 | 0 |
| 2 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 3.64966 | 0 | 1.96232 | 0 |
| 16 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 3.77261 | 0.000518799 | 2.11611 | 0 |
| 32 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 4.16686 | 0.000236511 | 2.2483 | 0 |
| 64 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 4.45646 | 0.000331879 | 2.55272 | 0 |
| 128 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 6.31186 | 0.000110626 | 11.6767 | 0 |
| 192 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 7.81408 | 0.000104268 | 15.3828 | 0 |
| 256 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 10.096 | 0.000151634 | 12.9491 | 0 |
| 512 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 16.904 | 0.000132084 | 13.9252 | 0 |
| 1024 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 28.8941 | 0.000131607 | 20.9024 | 0 |
| 16384 | 4096 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 332.715 | 9.91374e-05 | 329.895 | 2.98023e-08 |
| 1 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 6.09301 | 0 | 2.18569 | 0 |
| 2 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 5.53158 | 0 | 2.09682 | 0 |
| 16 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 6.17309 | 0 | 2.22647 | 0 |
| 32 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 6.34547 | 0.000667572 | 2.41126 | 0 |
| 64 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 7.87828 | 0 | 15.0149 | 0 |
| 128 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 11.1925 | 0.00028801 | 15.949 | 0 |
| 192 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 14.0472 | 0.000234604 | 15.7946 | 0 |
| 256 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 19.2459 | 0.000182629 | 12.6191 | 0 |
| 512 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 29.9609 | 0.000250578 | 19.6405 | 0 |
| 1024 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 52.7824 | 0.000243187 | 40.564 | 1.19209e-07 |
| 16384 | 8192 | 2 | torch.float8_e4m3fn | torch.bfloat16 | 672.725 | 0.000171259 | 660.283 | 8.9407e-08 |

The scalar multiplication fallback and warp shuffle replacements provide correct functionality while maintaining compatibility with the RDNA4 architecture.

Submission Checklist

@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|-------|-------|
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 3 --add-label <label>

#undef CK_TILE_TYPE_CONVERT

} // namespace ck_tile
} // namespace ck_tile   (no newline at end of file)
Member:
remove this line change @big-yellow-duck


return local;
}
}   (no newline at end of file)
Member:
@big-yellow-duck remove this unnecessary line change

asm volatile( "s_mov_b32 m0 %0\n\t"
"buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
asm volatile("s_mov_b32 m0, %0; buffer_load_dword %1, %2, 0 offen lds;" :: "s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc) : "memory");
Member:
@big-yellow-duck is this change necessary? It looks the same. If it is just an alignment issue, please revert to the original code. Let's not reformat the code, to avoid confusion; it makes the reviewer think you have modified the instruction.

Author:
The patch was to fix the smooth quant kernel on gfx1201, but it would affect gfx9 GPUs; changed to another implementation for gfx12 cards.
