Commit 41022d2

[FP8][Kernel] Add Cutlass integration for MoE FP8
Introduced optional support for using Cutlass kernels in the MoE FP8 execution path by converting the per-block scaling format into a per-tensor equivalent, making it compatible with the existing Cutlass kernel interface.
Parent: 5be9ad1 · Commit: 41022d2
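As background for the diff below, here is a toy, stand-alone illustration (not part of the commit) of the two scaling formats being converted between. Only the 128x128 block size comes from the diff; every other name and shape is made up.

import torch

# Per-block scaling stores one scale per 128x128 tile of the weight;
# per-tensor scaling stores a single scalar for the whole weight.
# Reconstructing an element differs only in which scale is applied.
M, N = 256, 384                                # toy weight shape
w_fp8 = torch.randn(M, N)                      # stand-in for the FP8 payload

block_scale = torch.rand(M // 128, N // 128)   # per-block: [M/128, N/128]
tensor_scale = torch.rand(())                  # per-tensor: a single scalar

i, j = 130, 200
per_block_value = w_fp8[i, j] * block_scale[i // 128, j // 128]
per_tensor_value = w_fp8[i, j] * tensor_scale
print(per_block_value.item(), per_tensor_value.item())

The commit converts the former layout into the latter so the weights can be fed to the existing Cutlass MoE kernel, which expects per-tensor scales.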

1 file changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 103 additions & 28 deletions
@@ -541,6 +541,26 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"

+            k = hidden_size
+            n = intermediate_size_per_partition
+            device = layer.w13_weight.device
+            self.ab_strides1 = torch.full((num_experts, ),
+                                          k,
+                                          device=device,
+                                          dtype=torch.int64)
+            self.c_strides1 = torch.full((num_experts, ),
+                                         2 * n,
+                                         device=device,
+                                         dtype=torch.int64)
+            self.ab_strides2 = torch.full((num_experts, ),
+                                          n,
+                                          device=device,
+                                          dtype=torch.int64)
+            self.c_strides2 = torch.full((num_experts, ),
+                                         k,
+                                         device=device,
+                                         dtype=torch.int64)
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
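A small sketch (not from the commit) of what the stride tensors built above contain for a toy configuration. Reading them as the per-expert leading dimensions of the operands of the two grouped GEMMs (x @ w13 producing 2n columns, then the intermediate @ w2 producing k columns) is an inference from the values k, 2n, n, k, not something stated in the diff.

import torch

# Toy sizes (illustrative only).
num_experts, k, n = 4, 7168, 2048

# Same construction as in create_weights above: one int64 entry per expert.
ab_strides1 = torch.full((num_experts, ), k, dtype=torch.int64)
c_strides1 = torch.full((num_experts, ), 2 * n, dtype=torch.int64)
ab_strides2 = torch.full((num_experts, ), n, dtype=torch.int64)
c_strides2 = torch.full((num_experts, ), k, dtype=torch.int64)

print(ab_strides1.tolist())  # [7168, 7168, 7168, 7168]
print(c_strides1.tolist())   # [4096, 4096, 4096, 4096]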
@@ -577,6 +597,22 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             layer.w13_input_scale = None
             layer.w2_input_scale = None

+    def fp8_bf16_fp8(self, fp8_tensor, fp8_scale):
+        blocked_tensor = fp8_tensor.view(
+            fp8_tensor.shape[0],
+            fp8_tensor.shape[1] // 128, 128,
+            fp8_tensor.shape[2] // 128,
+            128).to(torch.float32)
+        # Because blocked_tensor is 5D, reshape to [B, M//128, 1, N//128, 1]
+        dequant_tensor = (blocked_tensor *
+                          fp8_scale.unsqueeze(2).unsqueeze(4)).view(
+                              fp8_tensor.shape).to(torch.bfloat16).to(torch.float32)
+
+        scale_tensor = torch.abs(dequant_tensor).max() / 448
+        quant_tensor = dequant_tensor / scale_tensor
+
+        return quant_tensor, scale_tensor
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
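To see what fp8_bf16_fp8 gives up numerically, here is a toy round-trip check (not part of the commit): dequantize with per-block scales, requantize with one per-tensor scale, cast to FP8-E4M3, and compare against the dequantized reference. The 128 block size and the 448 constant mirror the helper above; the data and the error metric are made up, and torch.float8_e4m3fn assumes a reasonably recent PyTorch.

import torch

E, M, N = 2, 256, 384                                # toy expert weight shape
w = torch.randn(E, M, N)                             # stand-in FP8 payload
scales = torch.rand(E, M // 128, N // 128) + 0.5     # per-block scales

# Per-block dequantization, as in fp8_bf16_fp8 above.
blocked = w.view(E, M // 128, 128, N // 128, 128).float()
dequant = (blocked * scales.unsqueeze(2).unsqueeze(4)).view(E, M, N)

# Per-tensor requantization (448 is the FP8-E4M3 maximum).
scale = dequant.abs().max() / 448
requant = dequant / scale

# Simulate storage in FP8 and reconstruction with the per-tensor scale.
restored = requant.to(torch.float8_e4m3fn).to(torch.float32) * scale
rel_err = (restored - dequant).abs().max() / dequant.abs().max()
print(f"max relative error vs. per-block reference: {rel_err.item():.3e}")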
@@ -600,13 +636,29 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight = layer.w2_weight
                 w2_weight_scale_inv = layer.w2_weight_scale_inv

-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
-                                                   requires_grad=False)
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
-                                                  requires_grad=False)
+            if not (envs.VLLM_USE_CUTLASS_MOE_FP8):
+                # torch.compile() cannot use Parameter subclasses.
+                layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                       requires_grad=False)
+                layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                      requires_grad=False)
+            else:
+                w13_weight, w13_weight_scale_inv = \
+                    self.fp8_bf16_fp8(w13_weight, w13_weight_scale_inv)
+                w2_weight, w2_weight_scale_inv = \
+                    self.fp8_bf16_fp8(w2_weight, w2_weight_scale_inv)
+
+                w13_weight_scale_inv = w13_weight_scale_inv.repeat(w13_weight.size(0))
+                w2_weight_scale_inv = w2_weight_scale_inv.repeat(w2_weight.size(0))
+
+                layer.w13_weight.data.copy_(w13_weight)
+                layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv, requires_grad=False)
+                layer.w2_weight.data.copy_(w2_weight)
+
+                layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, requires_grad=False)
+
             if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
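Two details of the else branch above, shown on toy tensors (names local to this sketch, not from vLLM): the single 0-d per-tensor scale returned by fp8_bf16_fp8 is replicated so that each expert gets one entry, and the requantized float32 values are copied back into the existing parameter storage with Tensor.copy_, which casts to the parameter's dtype in place. The float8 dtype assumes a PyTorch build that provides it.

import torch

num_experts, m, n = 4, 8, 16

# Existing FP8 parameter storage (stand-in for layer.w13_weight / w2_weight).
fp8_param = torch.zeros(num_experts, m, n).to(torch.float8_e4m3fn)

dequant = torch.randn(num_experts, m, n) * 100.0   # dequantized weight (fp32)
scale = dequant.abs().max() / 448                  # 0-d per-tensor scale
requant = dequant / scale                          # |values| now <= 448

per_expert_scale = scale.repeat(num_experts)       # shape: [num_experts]
fp8_param.data.copy_(requant)                      # implicit cast to FP8

print(per_expert_scale.shape, fp8_param.dtype)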
@@ -800,27 +852,50 @@ def apply(
                 e_score_correction_bias=e_score_correction_bias,
             )

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=(layer.w13_weight_scale_inv
-                      if self.block_quant else layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale_inv
-                      if self.block_quant else layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+
+        if not (envs.VLLM_USE_CUTLASS_MOE_FP8):
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                use_fp8_w8a8=True,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
+            return cutlass_moe_fp8(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv,
+                layer.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                out_dtype=x.dtype,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+


 class Fp8KVCacheMethod(BaseKVCacheMethod):
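A hedged usage sketch for the new path. The diff reads envs.VLLM_USE_CUTLASS_MOE_FP8, but the corresponding definition in vllm/envs.py is not part of this file's diff, so treating it as an environment variable of the same name (like other VLLM_* flags) is an assumption; the model id below is a placeholder for any checkpoint whose MoE layers use block-quantized FP8.

import os

# Assumption: the flag is read from the environment like other VLLM_* settings.
os.environ["VLLM_USE_CUTLASS_MOE_FP8"] = "1"

from vllm import LLM

# Placeholder model id: an FP8 block-quantized MoE checkpoint handled by
# Fp8MoEMethod would take the cutlass_moe_fp8 branch in apply() above.
llm = LLM(model="some-org/fp8-block-quantized-moe-model")
print(llm.generate(["Hello"])[0].outputs[0].text)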
