@@ -67,13 +67,15 @@ def __init__(self,
67
67
activation = config .hidden_act ,
68
68
quant_config = quant_config )
69
69
70
- self .block_sparse_moe = GraniteMoeMoE (
71
- num_experts = config .num_local_experts ,
72
- top_k = config .num_experts_per_tok ,
73
- hidden_size = config .hidden_size ,
74
- intermediate_size = config .intermediate_size ,
75
- quant_config = quant_config ,
76
- prefix = f"{ prefix } .block_sparse_moe" )
70
+ self .block_sparse_moe = None
71
+ if getattr (config , "num_local_experts" , 0 ) > 0 :
72
+ self .block_sparse_moe = GraniteMoeMoE (
73
+ num_experts = config .num_local_experts ,
74
+ top_k = config .num_experts_per_tok ,
75
+ hidden_size = config .hidden_size ,
76
+ intermediate_size = config .intermediate_size ,
77
+ quant_config = quant_config ,
78
+ prefix = f"{ prefix } .block_sparse_moe" )
77
79
78
80
self .shared_mlp = None if \
79
81
getattr (config , 'shared_intermediate_size' , 0 ) == 0 \
@@ -105,13 +107,19 @@ def forward(
105
107
residual = hidden_states
106
108
hidden_states = self .post_attention_layernorm (hidden_states )
107
109
if self .shared_mlp is None :
108
- hidden_states = self .block_sparse_moe (hidden_states )
110
+ if self .block_sparse_moe is not None :
111
+ hidden_states = self .block_sparse_moe (hidden_states )
112
+ # else: skip
109
113
else :
110
114
# create a copy since block_sparse_moe modifies in-place
111
- moe_hidden_states = hidden_states .clone ()
112
- moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
113
- hidden_states = moe_hidden_states + self .shared_mlp (hidden_states )
114
- del moe_hidden_states
115
+ if self .block_sparse_moe is not None :
116
+ moe_hidden_states = hidden_states .clone ()
117
+ moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
118
+ hidden_states = moe_hidden_states + self .shared_mlp (
119
+ hidden_states )
120
+ del moe_hidden_states
121
+ else :
122
+ hidden_states = self .shared_mlp (hidden_states )
115
123
hidden_states = residual + hidden_states * self .residual_multiplier
116
124
117
125
return hidden_states , residual
@@ -137,13 +145,15 @@ def __init__(
137
145
quant_config = quant_config ,
138
146
prefix = f"{ prefix } .self_attn" )
139
147
140
- self .block_sparse_moe = GraniteMoeMoE (
141
- num_experts = config .num_local_experts ,
142
- top_k = config .num_experts_per_tok ,
143
- hidden_size = config .hidden_size ,
144
- intermediate_size = config .intermediate_size ,
145
- quant_config = quant_config ,
146
- prefix = f"{ prefix } .block_sparse_moe" )
148
+ self .block_sparse_moe = None
149
+ if getattr (config , "num_local_experts" , 0 ) > 0 :
150
+ self .block_sparse_moe = GraniteMoeMoE (
151
+ num_experts = config .num_local_experts ,
152
+ top_k = config .num_experts_per_tok ,
153
+ hidden_size = config .hidden_size ,
154
+ intermediate_size = config .intermediate_size ,
155
+ quant_config = quant_config ,
156
+ prefix = f"{ prefix } .block_sparse_moe" )
147
157
148
158
self .shared_mlp = None if \
149
159
getattr (config , 'shared_intermediate_size' , 0 ) == 0 \
@@ -178,13 +188,19 @@ def forward(
178
188
residual = hidden_states
179
189
hidden_states = self .post_attention_layernorm (hidden_states )
180
190
if self .shared_mlp is None :
181
- hidden_states = self .block_sparse_moe (hidden_states )
191
+ if self .block_sparse_moe is not None :
192
+ hidden_states = self .block_sparse_moe (hidden_states )
193
+ # else: skip
182
194
else :
183
195
# create a copy since block_sparse_moe modifies in-place
184
- moe_hidden_states = hidden_states .clone ()
185
- moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
186
- hidden_states = moe_hidden_states + self .shared_mlp (hidden_states )
187
- del moe_hidden_states
196
+ if self .block_sparse_moe is not None :
197
+ moe_hidden_states = hidden_states .clone ()
198
+ moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
199
+ hidden_states = moe_hidden_states + self .shared_mlp (
200
+ hidden_states )
201
+ del moe_hidden_states
202
+ else :
203
+ hidden_states = self .shared_mlp (hidden_states )
188
204
hidden_states = residual + hidden_states * self .residual_multiplier
189
205
190
206
return hidden_states , residual
0 commit comments