Skip to content
16 changes: 9 additions & 7 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif

// Divisor applied to a tile's / row's absolute maximum when deriving its FP8
// quantization scale (scale = amax / kFp8ScaleDivisor). The value is chosen at
// compile time: 224 when building for the gfx942 target, 448 everywhere else.
// NOTE(review): presumably these track the usable FP8 range on each target --
// confirm against the FP8 conversion helpers before changing either value.
constexpr float kFp8ScaleDivisor =
#ifdef __gfx942__
    224.0f;
#else
    448.0f;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) {
Expand Down Expand Up @@ -401,8 +407,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
}

// Compute the scale for the tile
float tile_scale = max_abs / 448.f;
tile_scale = fmaxf(tile_scale, FLT_MIN);
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN);

// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
Expand Down Expand Up @@ -471,11 +476,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
#endif
}

#if defined(__gfx942__)
float scale = fmaxf(amax, 1e-4) / 224.0f;
#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
#endif
float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor;

if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
Expand Down