Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiter/utility/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"gfx942": {"fp8": torch.float8_e4m3fnuz},
"gfx950": {"fp8": torch.float8_e4m3fn},
"gfx1250": {"fp8": torch.float8_e4m3fn},
"gfx1201": {"fp8": torch.float8_e4m3fn},
}

_8bit_fallback = torch.uint8
Expand Down
23 changes: 20 additions & 3 deletions csrc/include/ck_tile/vec_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b)
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// Lane-wise fp32x2 multiply written in plain C++ instead of inline assembly.
// Meant for RDNA3/RDNA4 targets, which do not provide the v_pk_mul_f32
// instruction that amd_assembly_pk_mul_f32 relies on.
// Returns a vector whose lane i holds a[i] * b[i] for i in {0, 1}.
CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b)
{
    // Compute each lane independently; there is no cross-lane dependency.
    const auto lane0 = a[0] * b[0];
    const auto lane1 = a[1] * b[1];

    fp32x2_v product;
    product[0] = lane0;
    product[1] = lane1;
    return product;
}
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b)
{
int16x2_t c;
Expand Down Expand Up @@ -145,7 +152,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
using vec_ti = vector_traits<fp32x2_v>;
constexpr int vec_size = vec_ti::vector_size;
constexpr auto interpret = numeric_traits<fp8_t>::f8_interpret;
fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
fp32x2_v tmp;
#if defined(__gfx11__) || defined(__gfx12__)
tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif

return (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
Expand All @@ -155,7 +167,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
// fp32x2 -> int8x2
CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
fp32x2_v tmp;
#if defined(__gfx11__) || defined(__gfx12__)
tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif

int8x2_v out;
out[0] = static_cast<int8_t>(tmp[0]);
Expand Down Expand Up @@ -251,4 +268,4 @@ CK_TILE_TYPE_CONVERT(fp4x2, bf16, 16)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 32)
#undef CK_TILE_TYPE_CONVERT

} // namespace ck_tile
} // namespace ck_tile
27 changes: 26 additions & 1 deletion csrc/include/hip_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,28 @@ __device__ constexpr T wave_reduce(T local, F reduce_op)

if constexpr(WarpSize > 16)
{
// DPP row broadcasts (0x142 = row_bcast:15, 0x143 = row_bcast:31) are not supported on GFX10+,
// so on gfx11/gfx12 use rocprim::warp_shuffle for cross-lane communication instead.
#if defined(__gfx12__) || defined(__gfx11__)
// gfx11/gfx12: read lane 15 with a shuffle instead of the DPP row_bcast:15
T v_remote = rocprim::warp_shuffle(local, 15, WarpSize);
local = reduce_op(v_remote, local);
#else
// row_bcast:15
local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(local), local);
#endif
}

if constexpr(WarpSize > 32)
{
#if defined(__gfx12__) || defined(__gfx11__)
// gfx11/gfx12: read lane 31 with a shuffle instead of the DPP row_bcast:31
T v_remote = rocprim::warp_shuffle(local, 31, WarpSize);
local = reduce_op(v_remote, local);
#else
// row_bcast:31
local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(local), local);
#endif
}

if constexpr(threadBroadcast && WarpSize > 4)
Expand Down Expand Up @@ -166,7 +180,12 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num)
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x4e>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x124>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x128>(data), data);
#if defined(__gfx12__) || defined(__gfx11__)
// DPP broadcast 0x142 (row_bcast:15) is not supported on gfx11/gfx12, use a shuffle instead
data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
#else
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142, 0xa>(data), data);
#endif
if constexpr(threadBroadcast)
{
data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
Expand All @@ -179,8 +198,14 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num)
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x4e>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x124>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x128>(data), data);
#if defined(__gfx12__) || defined(__gfx11__)
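// DPP broadcasts (0x142, 0x143) are not supported on gfx11/gfx12, use shuffles instead.
// NOTE(review): row_bcast:15 feeds each row from the previous row's lane 15, whereas
// warp_shuffle(data, 15, WarpSize) always reads lane 15 of the WarpSize-wide group;
// confirm the two agree when more than 32 lanes take part in the reduction.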
data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
data = reduce_op(rocprim::warp_shuffle(data, 31, WarpSize), data);
#else
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(data), data);
#endif
if constexpr(threadBroadcast)
{
data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
Expand Down Expand Up @@ -231,4 +256,4 @@ __device__ constexpr T block_reduce(T local, F reduce_op)
}

return local;
}
}
4 changes: 1 addition & 3 deletions csrc/kernels/quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,7 @@ __global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out,
// buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size);
const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
asm volatile( "s_mov_b32 m0 %0\n\t"
"buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
asm volatile("s_mov_b32 m0, %0; buffer_load_dword %1, %2, 0 offen lds;" :: "s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc) : "memory");
}
opus::s_waitcnt_vmcnt(opus::number<0>{});
__syncthreads();
Expand Down
Loading