@@ -541,6 +541,26 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
 
+        k = hidden_size
+        n = intermediate_size_per_partition
+        device = layer.w13_weight.device
+        self.ab_strides1 = torch.full((num_experts, ),
+                                      k,
+                                      device=device,
+                                      dtype=torch.int64)
+        self.c_strides1 = torch.full((num_experts, ),
+                                     2 * n,
+                                     device=device,
+                                     dtype=torch.int64)
+        self.ab_strides2 = torch.full((num_experts, ),
+                                      n,
+                                      device=device,
+                                      dtype=torch.int64)
+        self.c_strides2 = torch.full((num_experts, ),
+                                     k,
+                                     device=device,
+                                     dtype=torch.int64)
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
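A note on the new stride tensors: the CUTLASS grouped GEMM takes one leading-dimension stride per expert for each operand, and here every expert shares the same value, so each tensor is just a constant repeated num_experts times. A standalone shape sketch (toy sizes, assuming the usual vLLM layouts w13 = [E, 2N, K] and w2 = [E, K, N]):

import torch

# Hypothetical toy sizes: E experts, hidden size K, intermediate size N.
E, K, N = 4, 256, 128
w13 = torch.empty(E, 2 * N, K)  # fused gate/up projection per expert
w2 = torch.empty(E, K, N)       # down projection per expert

# One stride entry per expert; the stride is just the trailing GEMM dimension.
ab_strides1 = torch.full((E, ), w13.size(2), dtype=torch.int64)  # == K
c_strides1 = torch.full((E, ), w13.size(1), dtype=torch.int64)   # == 2 * N
ab_strides2 = torch.full((E, ), w2.size(2), dtype=torch.int64)   # == N
c_strides2 = torch.full((E, ), w2.size(1), dtype=torch.int64)    # == K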
@@ -577,6 +597,22 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             layer.w13_input_scale = None
             layer.w2_input_scale = None
 
+    def fp8_bf16_fp8(self, fp8_tensor, fp8_scale):
+        blocked_tensor = fp8_tensor.view(
+            fp8_tensor.shape[0],
+            fp8_tensor.shape[1] // 128, 128,
+            fp8_tensor.shape[2] // 128,
+            128).to(torch.float32)
+        # Expand fp8_scale to [B, M//128, 1, N//128, 1] to match the 5D blocks.
+        dequant_tensor = (blocked_tensor *
+                          fp8_scale.unsqueeze(2).unsqueeze(4)).view(
+                              fp8_tensor.shape).to(torch.bfloat16).to(torch.float32)
+
+        scale_tensor = torch.abs(dequant_tensor).max() / 448
+        quant_tensor = dequant_tensor / scale_tensor
+
+        return quant_tensor, scale_tensor
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
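For readers of fp8_bf16_fp8: it widens the 128x128 block-quantized weight to float32, multiplies each block by its scale, rounds through bfloat16, and then requantizes the whole tensor against a single per-tensor scale (448 is the largest finite float8_e4m3fn value). A standalone sketch of the same arithmetic, using float32 stand-ins for the FP8/BF16 dtypes so it runs anywhere:

import torch

# Hypothetical toy sizes; the real weights are [E, M, N] with 128x128 blocks.
E, M, N = 2, 256, 384
w = torch.randn(E, M, N)               # stand-in for the FP8 weight values
s = torch.rand(E, M // 128, N // 128)  # one scale per 128x128 block

blocks = w.view(E, M // 128, 128, N // 128, 128)
dequant = (blocks * s.unsqueeze(2).unsqueeze(4)).view(E, M, N)

per_tensor_scale = dequant.abs().max() / 448  # 448 = max finite float8_e4m3fn
requant = dequant / per_tensor_scale          # later cast back to fp8 on store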
@@ -600,13 +636,29 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight = layer.w2_weight
                 w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
-                                                   requires_grad=False)
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
-                                                  requires_grad=False)
+            if not envs.VLLM_USE_CUTLASS_MOE_FP8:
+                # torch.compile() cannot use Parameter subclasses.
+                layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                       requires_grad=False)
+                layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                      requires_grad=False)
+            else:
+                w13_weight, w13_weight_scale_inv = \
+                    self.fp8_bf16_fp8(w13_weight, w13_weight_scale_inv)
+                w2_weight, w2_weight_scale_inv = \
+                    self.fp8_bf16_fp8(w2_weight, w2_weight_scale_inv)
+
+                w13_weight_scale_inv = w13_weight_scale_inv.repeat(w13_weight.size(0))
+                w2_weight_scale_inv = w2_weight_scale_inv.repeat(w2_weight.size(0))
+
+                layer.w13_weight.data.copy_(w13_weight)
+                layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                       requires_grad=False)
+                layer.w2_weight.data.copy_(w2_weight)
+                layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                      requires_grad=False)
             if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
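Two details of the CUTLASS branch worth calling out: fp8_bf16_fp8 returns a scalar scale, so .repeat(num_experts) expands it into a per-expert vector (every entry identical), and data.copy_() writes the requantized values back into the existing FP8 parameter, casting to its storage dtype in place. A tiny sketch of both steps (hypothetical sizes, float16 as a stand-in for the FP8 storage dtype):

import torch

E, K = 8, 16
fp8_param = torch.zeros(E, K, dtype=torch.float16)  # stand-in for the fp8 buffer
requant = torch.randn(E, K)                         # float32 requantized values
fp8_param.copy_(requant)                            # copy_ casts in place

scalar_scale = torch.tensor(0.0123)
per_expert_scale = scalar_scale.repeat(E)           # shape [E], one entry per expert
assert per_expert_scale.shape == (E, )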
@@ -800,27 +852,50 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            w1_scale=(layer.w13_weight_scale_inv
-                      if self.block_quant else layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale_inv
-                      if self.block_quant else layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+
+        if not envs.VLLM_USE_CUTLASS_MOE_FP8:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                use_fp8_w8a8=True,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
+            return cutlass_moe_fp8(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv,
+                layer.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                out_dtype=x.dtype,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
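Rough usage sketch: the new path is gated on envs.VLLM_USE_CUTLASS_MOE_FP8. Assuming it follows vLLM's usual convention of mirroring an environment variable of the same name, it would be toggled roughly like this (the model id is a placeholder):

import os

# Assumed env-var name, mirroring envs.VLLM_USE_CUTLASS_MOE_FP8; set it
# before importing vLLM so the flag is read at startup.
os.environ["VLLM_USE_CUTLASS_MOE_FP8"] = "1"

from vllm import LLM

llm = LLM(model="path/to/fp8-block-quantized-moe-model")  # placeholder model id
print(llm.generate("Hello")[0].outputs[0].text)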