@@ -16,7 +16,7 @@
#

import math
- from typing import Any, Callable, Dict, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -32,6 +32,80 @@
                               npu_stream_switch, npu_wait_tensor)


+ def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
+                      w1: torch.Tensor,
+                      w1_scale: torch.Tensor,
+                      w2: torch.Tensor,
+                      w2_scale: torch.Tensor,
+                      group_list: torch.Tensor,
+                      dynamic_scale: torch.Tensor = None,
+                      group_list_type: int = 1) -> torch.Tensor:
+     """
+     apply MLP: gate_up_proj -> swiglu -> down_proj
+     Args:
+         hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+         w1: expert weights1 with shape
+             (num_experts, hidden_size, intermediate_size * 2)
+         w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+         w2: expert weights2 with shape
+             (num_experts, intermediate_size, hidden_size)
+         w2_scale: weights2 scale with shape (num_experts, hidden_size)
+         group_list: number of tokens for each expert, follow cumsum mode, and
+             with shape (num_experts).
+         transpose_weight:
+             w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                 (num_experts, hidden_size, intermediate_size * 2)
+             w2: (num_experts, hidden_size, intermediate_size) ->
+                 (num_experts, intermediate_size, hidden_size)
+     Returns:
+         hidden_states: output hidden states after MLP.
+     """
+ 
+     assert len(hidden_states_wrapper) == 1
+     hidden_states = hidden_states_wrapper.pop()
+     if dynamic_scale is None:
+         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+             hidden_states)
+     else:
+         pertoken_scale = dynamic_scale
+ 
+     # gmm1: gate_up_proj
+     hidden_states = torch_npu.npu_grouped_matmul(
+         x=[hidden_states],
+         weight=[w1],
+         split_item=3,
+         group_list_type=group_list_type,
+         group_type=0,
+         group_list=group_list,
+         output_dtype=torch.int32)[0]
+ 
+     # act_fn: swiglu
+     hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+         x=hidden_states,
+         weight_scale=w1_scale,
+         activation_scale=pertoken_scale,
+         bias=None,
+         quant_scale=None,
+         quant_offset=None,
+         group_index=group_list,
+         activate_left=True,
+         quant_mode=1,
+     )
+ 
+     # gmm2: down_proj
+     hidden_states = torch_npu.npu_grouped_matmul(
+         x=[hidden_states],
+         weight=[w2],
+         scale=[w2_scale],
+         per_token_scale=[swiglu_out_scale],
+         split_item=2,
+         group_list_type=group_list_type,
+         group_type=0,
+         group_list=group_list,
+         output_dtype=w2_scale.dtype)[0]
+     return hidden_states
+ 
+ 
def apply_mlp(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
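The three fused kernels in apply_mlp_decode implement the usual grouped-expert pipeline: a gate_up projection, a SwiGLU activation, then a down projection, carried out on int8 data with per-channel weight scales and per-token activation scales. Ignoring the quantization details, the float-precision equivalent looks roughly like the sketch below (plain PyTorch, illustration only; it assumes group_list holds cumulative token counts, as the docstring's "cumsum mode" suggests, and reference_grouped_mlp is a name invented here):

import torch

def reference_grouped_mlp(hidden_states, w1, w2, group_list):
    # hidden_states: (num_tokens, hidden); w1: (num_experts, hidden, 2 * inter)
    # w2: (num_experts, inter, hidden); group_list: cumulative tokens per expert
    out = torch.empty_like(hidden_states)
    start = 0
    for e, end in enumerate(group_list.tolist()):
        gate, up = (hidden_states[start:end] @ w1[e]).chunk(2, dim=-1)
        out[start:end] = (torch.nn.functional.silu(gate) * up) @ w2[e]  # SwiGLU, then down_proj
        start = end
    return out

x = torch.randn(6, 16)
y = reference_grouped_mlp(x, torch.randn(2, 16, 8), torch.randn(2, 4, 16),
                          torch.tensor([4, 6]))

The real kernels additionally carry w1_scale / w2_scale and pertoken_scale / swiglu_out_scale so that the two grouped matmuls can run on quantized data and be dequantized on the way out.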
@@ -138,7 +212,9 @@ def fused_experts_with_mc2(
        shared_experts: Optional[Any] = None,
        is_torchair: bool = False,
        w1_scale_bias: torch.Tensor = None,
-         w2_scale_bias: torch.Tensor = None
+         w2_scale_bias: torch.Tensor = None,
+         quantized_x_for_share: Optional[Any] = None,
+         dynamic_scale_for_share: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if log2phy:
        topk_ids = log2phy[topk_ids]
@@ -193,21 +269,19 @@ def fused_experts_with_mc2(

    if shared_experts is not None:
        with npu_stream_switch("moe_secondary", 0):
-             npu_wait_tensor(hidden_states, topk_weights)
-             shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-             npu_wait_tensor(shared_gate_up[0], expand_x)
-             shared_act = shared_experts.act_fn(shared_gate_up)
+             npu_wait_tensor(quantized_x_for_share, expand_x)
+             shared_act_out = shared_experts.act_fn(
+                 (quantized_x_for_share, dynamic_scale_for_share))
+             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

    # `expand_x` will be disposed in the `apply_mlp` function
-     down_out_list = apply_mlp(expand_x,
-                               w1,
-                               w1_scale,
-                               w2,
-                               w2_scale,
-                               expert_token_nums,
-                               dynamic_scale=dynamic_scale,
-                               w1_scale_bias=w1_scale_bias,
-                               w2_scale_bias=w2_scale_bias)
+     down_out_list = apply_mlp_decode([expand_x],
+                                      w1,
+                                      w1_scale,
+                                      w2,
+                                      w2_scale,
+                                      expert_token_nums,
+                                      dynamic_scale=dynamic_scale)

    # moeCombine
    kwargs_mc2 = {
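A side note on the hunk above: expand_x is now handed to apply_mlp_decode inside a one-element list, and the callee pops it before quantizing. This lets the callee take over the reference so the unquantized activation can be released as early as possible (the "`expand_x` will be disposed" comment refers to this), though the actual freeing also depends on how the caller manages its own reference. A reduced illustration of the hand-off pattern, with a stand-in for the quantization step (hypothetical names, plain Python):

import torch

def quantize_and_drop(wrapper):
    assert len(wrapper) == 1
    x = wrapper.pop()                      # the list no longer references x
    q = (x * 127).round().to(torch.int8)   # stand-in for npu_dynamic_quant
    return q                               # x itself dies when this frame returns

activation = torch.rand(1024, 1024)
holder = [activation]
del activation                 # caller drops its own name as well
q = quantize_and_drop(holder)
print(len(holder))             # 0 -- the float tensor is now collectable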
@@ -244,8 +318,9 @@ def fused_experts_with_mc2(
        return hidden_states
    else:
        with npu_stream_switch("moe_secondary", 0):
-             npu_wait_tensor(shared_act[0], down_out_list)
-             shared_output, _ = shared_experts.down_proj(shared_act)
+             npu_wait_tensor(shared_act, down_out_list)
+             shared_output, _ = shared_experts.down_proj(
+                 (shared_act, swiglu_out_scale))
        return hidden_states, shared_output


@@ -661,6 +736,8 @@ def apply(
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
+         quantized_x_for_share: Optional[Any] = None,
+         dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
@@ -695,6 +772,16 @@ def apply(
            e_score_correction_bias=e_score_correction_bias,
        )

+         fused_moe_state = get_forward_context().fused_moe_state
+         shared_gate_up, shared_dequant_scale = None, None
+         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+             with npu_stream_switch("moe_secondary", 0):
+                 npu_wait_tensor(quantized_x_for_share, router_logits)
+                 share_up_out, _ = shared_experts.gate_up_proj(
+                     (quantized_x_for_share, dynamic_scale_for_share))
+                 shared_gate_up, shared_dequant_scale = share_up_out[
+                     0], share_up_out[1]
+ 
        # this is a naive implementation for expert load balancing, so as
        # to avoid accumulating too many tokens on a single rank.
        # currently it is only activated when doing profile runs.
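The quantized_x_for_share / dynamic_scale_for_share pair used in this hunk is the shared-expert input in per-token dynamically quantized form plus its per-token scale, presumably produced once by the caller and reused here so the shared expert's gate_up_proj can consume quantized data directly. For intuition, a plain-torch sketch of symmetric per-token int8 quantization follows; this is an assumption about the general scheme (per_token_dynamic_quant is a name invented here), not the exact behaviour of npu_dynamic_quant:

import torch

def per_token_dynamic_quant(x):
    # one scale per token (row), symmetric int8
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

x = torch.randn(4, 16)
q, pertoken_scale = per_token_dynamic_quant(x)
recovered = q.float() * pertoken_scale.unsqueeze(-1)
print((x - recovered).abs().max())   # small rounding error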
@@ -703,13 +790,12 @@ def apply(

        topk_weights = topk_weights.to(x.dtype)

-         fused_moe_state = get_forward_context().fused_moe_state
        if fused_moe_state == FusedMoEState.MC2:
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
-                 w1_scale=layer.w13_weight_scale,
+                 w1_scale=layer.w13_weight_scale_fp32,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
@@ -719,7 +805,9 @@ def apply(
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
                shared_experts=shared_experts,
-                 is_torchair=self.torchair_graph_enabled)
+                 is_torchair=self.torchair_graph_enabled,
+                 quantized_x_for_share=shared_gate_up,
+                 dynamic_scale_for_share=shared_dequant_scale)
        elif fused_moe_state == FusedMoEState.AllGather:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
@@ -764,6 +852,8 @@ def process_weights_after_loading(self, layer):
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
+         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+             torch.float32)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
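The w13_weight_scale_fp32 buffer added in this last hunk is the gate_up weight scale cast to float32 once at load time; the MC2 path above then passes it as w1_scale, presumably so the decode kernels receive a float32 scale without repeating the cast on every step. A trivial sketch of the trade-off (hypothetical dtype and shape, not taken from this diff):

import torch

# hypothetical per-expert gate_up scale as it comes out of weight loading
w13_weight_scale = torch.rand(64, 2 * 1408, dtype=torch.bfloat16)

# cast once and keep both copies: a little extra memory in exchange for
# not running .to(torch.float32) on the scale in every forward pass
w13_weight_scale_fp32 = w13_weight_scale.to(torch.float32)
print(w13_weight_scale_fp32.dtype, tuple(w13_weight_scale_fp32.shape))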