6 | 6 | from typing import Any |
7 | 7 |
8 | 8 | import regex as re |
| 9 | +import torch |
| 10 | +from aiter.ops.triton.quant import dynamic_mxfp4_quant |
9 | 11 |
10 | 12 |
11 | 13 | def deep_compare(dict1: Any, dict2: Any) -> bool: |
@@ -103,3 +105,108 @@ def _is_equal_or_regex_match( |
103 | 105 | elif target == value: |
104 | 106 | return True |
105 | 107 | return False |
| 108 | + |
| 109 | + |
| 110 | +def quant_to_mxfp4(x): |
| 111 | + """ |
| 112 | + Quantize the input tensor x to MXFP4 format. |
| 113 | + """ |
| 114 | + h, b, d = x.shape |
| 115 | + x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d)) |
| 116 | + return x.view(h, b, d // 2), x_scales.view(h, b, d // 32) |
| 117 | + |
| 118 | + |
| 119 | +def dequant_mxfp4_to_fp32(x, is_threed): |
| 120 | + """ |
| 121 | + Dequantize the input tensor x from MXFP4 format to fp32. |
| 122 | + """ |
| 123 | + # repeat-interleave 2x because two mxfp4 values are packed into each uint8 |
| 124 | + x = x.repeat_interleave(2, dim=-1) |
| 125 | + if is_threed: |
| 126 | + x[..., ::2] = x[..., ::2] & 0xF |
| 127 | + x[..., 1::2] = x[..., 1::2] >> 4 |
| 128 | + else: |
| 129 | + x[:, ::2] = x[:, ::2] & 0xF |
| 130 | + x[:, 1::2] = x[:, 1::2] >> 4 |
| 131 | + |
| 132 | + mxfp4_list = [  # FP4 (E2M1) code -> float value lookup table |
| 133 | + 0.0, |
| 134 | + 0.5, |
| 135 | + 1.0, |
| 136 | + 1.5, |
| 137 | + 2.0, |
| 138 | + 3.0, |
| 139 | + 4.0, |
| 140 | + 6.0, |
| 141 | + -0.0, |
| 142 | + -0.5, |
| 143 | + -1.0, |
| 144 | + -1.5, |
| 145 | + -2.0, |
| 146 | + -3.0, |
| 147 | + -4.0, |
| 148 | + -6.0, |
| 149 | + ] |
| 150 | + mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device) |
| 151 | + return mxfp4_in_f32[x.long()] |
| 152 | + |
| 153 | + |
| 154 | +def convert_e8m0_to_fp32(x): |
| 155 | + """ |
| 156 | + Convert the input tensor x from e8m0 format to fp32 format |
| 157 | + """ |
| 158 | + # Convert the input tensor `x` (assumed to be in |
| 159 | + # e8m0 format) to float32. e8m0 is a custom 8-bit |
| 160 | + # floating point format with 8 bits for exponent, 0 for mantissa. |
| 161 | + # This means the value is essentially 2^(exponent - 127), |
| 162 | + # similar to how IEEE-754 stores floats. |
| 163 | + |
| 164 | + # Convert x to float32 for computation, and |
| 165 | + # compute the power of 2 by subtracting the bias (127). |
| 166 | + x_f32 = 2 ** ((x.to(torch.float32)) - 127) |
| 167 | + |
| 168 | + # If the exponent value was 255 (i.e., 2^(128)), this |
| 169 | + # is a special case usually used to represent NaN or Inf. |
| 170 | + # Since this custom format has no mantissa, treat 2^128 as NaN. |
| 171 | + x_f32[x == 255] = float("nan") |
| 172 | + return x_f32 |
| 173 | + |
| 174 | + |
| 175 | +def quark_post_load_weights( |
| 176 | + qk_nope_head_dim: int, |
| 177 | + v_head_dim: int, |
| 178 | + weight: torch.Tensor, |
| 179 | + weight_scale: torch.Tensor | None = None, |
| 180 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 181 | + """ |
| 182 | + Post-process loaded weights for the Quark MXFP4 BMM path. |
| 183 | + """ |
| 184 | + |
| 185 | + def _quant_and_split_weight(loaded_weight: torch.Tensor): |
| 186 | + W_UK, W_UV = loaded_weight.unflatten( |
| 187 | + 0, (-1, (qk_nope_head_dim + v_head_dim)) |
| 188 | + ).split([qk_nope_head_dim, v_head_dim], dim=1) |
| 189 | + W_UK, W_UK_scale = quant_to_mxfp4(W_UK.transpose(-2, -1)) |
| 190 | + W_UV, W_UV_scale = quant_to_mxfp4(W_UV) |
| 191 | + W_UK_scale = W_UK_scale.contiguous() |
| 192 | + W_UV_scale = W_UV_scale.contiguous() |
| 193 | + return W_UK, W_UK_scale, W_UV, W_UV_scale |
| 194 | + |
| 195 | + # weight: [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)] |
| 196 | + # for a model with BF16 weights to use the MXFP4 BMM path, |
| 197 | + # quantize the weight to packed uint8 (two MXFP4 values per byte) |
| 198 | + if weight.dtype == torch.bfloat16: |
| 199 | + W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight) |
| 200 | + elif weight.dtype == torch.uint8: |
| 201 | + assert weight_scale is not None, ( |
| 202 | + "[Error][ROCm] weight_scale is required for U8 weight" |
| 203 | + ) |
| 204 | + weight = dequant_mxfp4_to_fp32(weight, True).to(torch.bfloat16) |
| 205 | + weight_scale = weight_scale.repeat_interleave(32, dim=-1) |
| 206 | + weight_scale = convert_e8m0_to_fp32(weight_scale).to(torch.bfloat16) |
| 207 | + weight = weight * weight_scale |
| 208 | + W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight) |
| 209 | + else: |
| 210 | + raise ValueError(f"[Error][ROCm] Unsupported weight dtype: {weight.dtype}") |
| 211 | + |
| 212 | + return W_UK, W_UK_scale, W_UV, W_UV_scale |
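
For reference, a minimal round-trip sketch of how these helpers compose. The 3-D shape, the CUDA/ROCm device, and the availability of aiter's dynamic_mxfp4_quant kernel are assumptions; it mirrors the uint8 branch of quark_post_load_weights rather than any specific checkpoint layout.

import torch

# Assumes the helpers above (quant_to_mxfp4, dequant_mxfp4_to_fp32,
# convert_e8m0_to_fp32) are in scope and a ROCm/CUDA device is available.
# Illustrative 3-D shape (4, 512, 128); the last dim must be a multiple
# of 32, since there is one e8m0 scale per 32-element block.
w = torch.randn(4, 512, 128, dtype=torch.bfloat16, device="cuda")

# Quantize: packed uint8 values (two FP4 codes per byte) plus block scales.
w_q, w_scale = quant_to_mxfp4(w)  # shapes (4, 512, 64) and (4, 512, 4)

# Dequantize: unpack the FP4 codes to fp32, broadcast the e8m0 scales to
# one value per element, and multiply, as in the uint8 branch above.
w_dq = dequant_mxfp4_to_fp32(w_q, True).to(torch.bfloat16)
scales = convert_e8m0_to_fp32(w_scale.repeat_interleave(32, dim=-1)).to(torch.bfloat16)
w_approx = w_dq * scales  # lossy reconstruction of w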