
Commit b4152da

[ROCm][FP4 BMM] integrate FP4 BMM

Authored and committed by ZhiweiYan-96 and zejunchen-zejun
Signed-off-by: zejunchen-zejun <[email protected]>

1 parent c5fa4c8, commit b4152da

File tree: 5 files changed, +313 −163 lines

evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh
Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,8 @@ export VLLM_DISABLE_COMPILE_CACHE=1
 export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
 export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc
 
+export VLLM_ROCM_USE_AITER_BMM=1
+
 export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
 export TRITON_HIP_USE_ASYNC_COPY=1
 export TRITON_HIP_USE_BLOCK_PINGPONG=1

@@ -38,6 +40,7 @@ vllm serve $model_path \
 --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
 --gpu_memory_utilization 0.8 \
 --async-scheduling \
+--enforce-eager \
 --block-size 16 \
 --load-format fastsafetensors \
 --seed 123 2>&1 | tee log.server.log &
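
The launch script enables the new path purely through an environment variable, so the same toggle also works for offline inference. Below is a minimal, hypothetical sketch of that usage; the model path and engine arguments are placeholders and are not taken from this commit:

import os

# Set the flag before the engine starts so vllm/envs.py picks it up.
# "1"/"true" enable the AITER BMM path, "0"/"false" disable it.
os.environ["VLLM_ROCM_USE_AITER_BMM"] = "1"

from vllm import LLM  # standard vLLM offline entry point

# Placeholder model path; enforce_eager mirrors the --enforce-eager flag
# that the script adds for the serve run.
llm = LLM(model="/path/to/DeepSeek-R1-FP4", enforce_eager=True)
out = llm.generate(["Hello"])
print(out[0].outputs[0].text)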

vllm/envs.py
Lines changed: 4 additions & 4 deletions

@@ -111,7 +111,7 @@
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
     VLLM_ROCM_USE_TRITON_ROPE: bool = True
     VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
-    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
+    VLLM_ROCM_USE_AITER_BMM: bool = True
     VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True

@@ -929,8 +929,8 @@ def get_vllm_port() -> int | None:
     ),
     # Whether to use aiter triton fp8 bmm kernel
     # By default is enabled.
-    "VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
+    "VLLM_ROCM_USE_AITER_BMM": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_BMM", "True").lower() in ("true", "1")
     ),
     # Use AITER triton unified attention for V1 attention
     "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (

@@ -1579,7 +1579,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_MHA",
         "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
         "VLLM_ROCM_USE_TRITON_ROPE",
-        "VLLM_ROCM_USE_AITER_FP8BMM",
+        "VLLM_ROCM_USE_AITER_BMM",
         "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",

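The renamed entry keeps the same parsing rule: the default is on, and only the strings "true" and "1" (case-insensitive) count as enabled; anything else disables the path. A small self-contained sketch of that behaviour, re-implementing the registered lambda rather than importing vllm:

import os


def use_aiter_bmm() -> bool:
    # Mirrors the lambda registered for VLLM_ROCM_USE_AITER_BMM in vllm/envs.py:
    # default "True"; only "true"/"1" (case-insensitive) enable the path.
    return os.getenv("VLLM_ROCM_USE_AITER_BMM", "True").lower() in ("true", "1")


for value in (None, "1", "True", "0", "false", "yes"):
    if value is None:
        os.environ.pop("VLLM_ROCM_USE_AITER_BMM", None)
    else:
        os.environ["VLLM_ROCM_USE_AITER_BMM"] = value
    print(f"{value!r:8} -> {use_aiter_bmm()}")
# unset -> True, "1"/"True" -> True, "0"/"false" -> False, "yes" -> False

Note that the flag is also listed in compute_hash(), so flipping it changes the hash vLLM derives from these environment variables.
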
vllm/model_executor/layers/quantization/quark/utils.py
Lines changed: 107 additions & 0 deletions

@@ -6,6 +6,8 @@
 from typing import Any
 
 import regex as re
+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:

@@ -103,3 +105,108 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
+
+
+def quant_to_mxfp4(x):
+    """
+    Quantize the input tensor x to MXFP4 format.
+    """
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def dequant_mxfp4_to_fp32(x, is_threed):
+    """
+    Dequantize the input tensor x from MXFP4 format to fp32 format.
+    """
+    # repeat-interleave 2x because two MXFP4 values are packed per uint8
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def convert_e8m0_to_fp32(x):
+    """
+    Convert the input tensor x from e8m0 format to fp32 format.
+    """
+    # Convert the input tensor `x` (assumed to be in
+    # e8m0 format) to float32. e8m0 is a custom 8-bit
+    # floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127),
+    # similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and
+    # compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # An exponent byte of 255 (i.e., 2^128) is a special case usually
+    # used to represent NaN or Inf. Since this custom format has no
+    # mantissa, treat it as NaN here.
+    x_f32[x == 255] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(
+    qk_nope_head_dim: int,
+    v_head_dim: int,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Post-load weight processing for Quark MXFP4 BMM.
+    """
+
+    def _quant_and_split_weight(loaded_weight: torch.Tensor):
+        W_UK, W_UV = loaded_weight.unflatten(
+            0, (-1, (qk_nope_head_dim + v_head_dim))
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        W_UK, W_UK_scale = quant_to_mxfp4(W_UK.transpose(-2, -1))
+        W_UV, W_UV_scale = quant_to_mxfp4(W_UV)
+        W_UK_scale = W_UK_scale.contiguous()
+        W_UV_scale = W_UV_scale.contiguous()
+        return W_UK, W_UK_scale, W_UV, W_UV_scale
+
+    # weight: [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]
+    # For the model with BF16 weights to use MXFP4 BMM,
+    # quantize the weight to packed uint8 format (2 x MXFP4 per byte).
+    if weight.dtype == torch.bfloat16:
+        W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight)
+    elif weight.dtype == torch.uint8:
+        assert weight_scale is not None, (
+            "[Error][ROCm] weight_scale is required for U8 weight"
+        )
+        weight = dequant_mxfp4_to_fp32(weight, True).to(torch.bfloat16)
+        weight_scale = weight_scale.repeat_interleave(32, dim=-1)
+        weight_scale = convert_e8m0_to_fp32(weight_scale).to(torch.bfloat16)
+        weight = weight * weight_scale
+        W_UK, W_UK_scale, W_UV, W_UV_scale = _quant_and_split_weight(weight)
+    else:
+        raise ValueError(f"[Error][ROCm] Unsupported weight dtype: {weight.dtype}")
+
+    return W_UK, W_UK_scale, W_UV, W_UV_scale
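
One way to sanity-check the helpers above is a round trip that mirrors the torch.uint8 branch of quark_post_load_weights: quantize a BF16 tensor, then rebuild it from the packed FP4 codes and the e8m0 block scales. The sketch below is not part of the commit; it assumes a ROCm machine with aiter installed, and the shapes are arbitrary:

import torch
from vllm.model_executor.layers.quantization.quark.utils import (
    convert_e8m0_to_fp32,
    dequant_mxfp4_to_fp32,
    quant_to_mxfp4,
)

# Arbitrary head/batch/feature sizes; d must be a multiple of 32 so each row
# splits into whole MXFP4 blocks of 32 values sharing one e8m0 scale.
h, b, d = 4, 8, 128
w = torch.randn(h, b, d, dtype=torch.bfloat16, device="cuda")

# Quantize: packed uint8 values (two FP4 codes per byte) plus e8m0 block scales.
w_q, w_scale = quant_to_mxfp4(w)
assert w_q.shape == (h, b, d // 2) and w_scale.shape == (h, b, d // 32)

# Dequantize, mirroring the uint8 branch of quark_post_load_weights:
# expand the codes to fp32 values, broadcast the per-32-element scales, multiply.
w_deq = dequant_mxfp4_to_fp32(w_q, True)
scale_f32 = convert_e8m0_to_fp32(w_scale.repeat_interleave(32, dim=-1))
w_rec = (w_deq * scale_f32).to(torch.bfloat16)

# MXFP4 (E2M1) keeps a single mantissa bit, so expect only a coarse match.
print("max abs err:", (w_rec.float() - w.float()).abs().max().item())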
