 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):

-    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
+    def __init__(self,
+                 full_hidden_size: int,
+                 full_n_groups: int,
+                 use_rms_norm: bool = True,
+                 eps: float = 1e-6):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -44,11 +48,17 @@ def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
         self.n_groups = full_hidden_size // self.group_size

         self.variance_epsilon = eps
-        self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
-        set_weight_attrs(self.weight,
-                         {"weight_loader": sharded_weight_loader(0)})
-        assert self.full_hidden_size % self.tp_size == 0,\
-            "Tensor parallel world size must divide hidden size."
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight,
+                             {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (self.full_hidden_size % self.tp_size == 0
+                ), "Tensor parallel world size must divide hidden size."

     def forward_native(
         self,
@@ -66,6 +76,8 @@ def forward_native(
         # the input and then redundantly compute the RMSNorm.
         input_dtype = x.dtype
         x = x * nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x

         if self.n_groups == 1:
             if self.tp_size > 1:
@@ -74,7 +86,7 @@ def forward_native(
                 global_sums = tensor_model_parallel_all_reduce(local_sums)
                 # Calculate the variance
                 count = self.tp_size * x.shape[-1]
-                variance = (global_sums / count)
+                variance = global_sums / count

             else:
                 variance = x.pow(2).mean(-1, keepdim=True)
@@ -106,6 +118,9 @@ def forward_cuda(
         gate: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

+        if not self.use_rms_norm:
+            return x * nn.functional.silu(gate.to(torch.float32))
+
         if self.tp_size > 1 or self.n_groups != 1:
             return self.forward_native(x, gate)

@@ -124,7 +139,7 @@ def forward_cuda(


 def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for 
+    """Compute the increase in group numbers to account for
     replication in order to accompany the head shards."""

     # in the case ngoups % tp_size == 0, this will be zero
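For reference, the hunks above change Mixer2RMSNormGated so that the SiLU-gated multiply always runs, while the RMS normalization (and its weight parameter) can be switched off via use_rms_norm. Below is a minimal single-device sketch of that behavior; the class name GatedRMSNormSketch is hypothetical, and grouping, tensor parallelism, and the fused CUDA path are omitted.

```python
import torch
import torch.nn as nn


class GatedRMSNormSketch(nn.Module):
    """Single-device sketch of the gated RMSNorm behavior in the diff above."""

    def __init__(self, hidden_size: int, use_rms_norm: bool = True, eps: float = 1e-6):
        super().__init__()
        self.use_rms_norm = use_rms_norm
        self.variance_epsilon = eps
        # Register the norm weight only when RMSNorm is actually applied,
        # mirroring the register_parameter("weight", None) branch.
        self.weight = nn.Parameter(torch.ones(hidden_size)) if use_rms_norm else None

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # The gate always passes through SiLU and multiplies the input.
        x = x * nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return x
        # Plain RMS normalization (n_groups == 1, no tensor parallelism).
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x.to(input_dtype)


x, gate = torch.randn(2, 8), torch.randn(2, 8)
print(GatedRMSNormSketch(8)(x, gate).shape)                       # torch.Size([2, 8])
print(GatedRMSNormSketch(8, use_rms_norm=False)(x, gate).shape)   # torch.Size([2, 8])
```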
@@ -182,13 +197,15 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # seem to handle slices well.
             # https://github.com/python/mypy/issues/2410
             param.data[
-                boundary:(boundary + take),  # type: ignore[misc]
-                ...] = loaded_weight[loaded_start_idx:(  # type: ignore[misc]
-                    loaded_start_idx + take)]  # type: ignore[misc]
+                boundary:(boundary + take),
+                ...  # type: ignore[misc]
+            ] = loaded_weight[loaded_start_idx:(loaded_start_idx +
+                                                take)  # type: ignore[misc]
+                              ]  # type: ignore[misc]

             # move indexing boundaries
             boundary += shard_size
-            loaded_boundary += (full_dim - extra)
+            loaded_boundary += full_dim - extra

     return loader

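The indexing being reindented above implements a sliced copy: for each segment in the shard spec, this rank's slice of the checkpoint tensor is written into the packed per-rank parameter, and the read and write offsets are then advanced. A deliberately simplified toy version of that pattern follows; the shard_spec of plain segment sizes and the function name are made up for illustration, and the real loader additionally handles group duplication and the extra/ratio bookkeeping.

```python
import torch


def toy_sharded_copy(param: torch.Tensor, loaded_weight: torch.Tensor,
                     shard_spec: list[int], tp_size: int, tp_rank: int) -> None:
    """Copy this rank's slice of each packed segment into the local param."""
    boundary, loaded_boundary = 0, 0  # write offset, read offset
    for full_dim in shard_spec:
        shard_size = full_dim // tp_size
        loaded_start_idx = loaded_boundary + tp_rank * shard_size
        param.data[boundary:boundary + shard_size] = \
            loaded_weight[loaded_start_idx:loaded_start_idx + shard_size]
        # move indexing boundaries
        boundary += shard_size
        loaded_boundary += full_dim


# Two packed segments of 8 and 4 rows, split across tp_size=2 ranks.
full = torch.arange(48, dtype=torch.float32).reshape(12, 4)
local = torch.zeros(6, 4)
toy_sharded_copy(local, full, shard_spec=[8, 4], tp_size=2, tp_rank=1)
print(local[:, 0])  # rows 4-7 of segment one, then rows 10-11 of segment two
```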
@@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
     **selective** state spaces)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 ssm_state_size: int,
-                 conv_kernel_size: int,
-                 intermediate_size: int,
-                 use_conv_bias: bool,
-                 use_bias: bool,
-                 n_groups: int = 1,
-                 num_heads: int = 128,
-                 head_dim: int = 64,
-                 rms_norm_eps: float = 1e-5,
-                 activation="silu",
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()

         # For TP, the sharding plan is as follows:
@@ -238,17 +258,16 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()

-        assert num_heads % self.tp_size == 0, \
-            "Tensor parallel world size must divide num heads."
+        assert (num_heads % self.tp_size == 0
+                ), "Tensor parallel world size must divide num heads."

-        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
-            (
-                "If tensor parallel world size does not divide num_heads, "
-                "then num_groups must equal 1."
-            )
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
+            "If tensor parallel world size does not divide num_heads, "
+            "then num_groups must equal 1.")

-        assert self.tp_size == 1 or quant_config is None, \
-            "Tensor parallel currently not supported for quantized models."
+        assert (
+            self.tp_size == 1 or quant_config is None
+        ), "Tensor parallel currently not supported for quantized models."

         self.ssm_state_size = ssm_state_size
         self.activation = activation
@@ -265,8 +284,7 @@ def __init__(self,
         self.n_groups = n_groups + extra_groups_for_head_shards(
             n_groups, self.tp_size)

-        self.conv_dim = (intermediate_size +
-                         2 * self.n_groups * ssm_state_size)
+        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
             output_size=self.conv_dim,
@@ -279,11 +297,12 @@ def __init__(self,
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        self.in_proj = ColumnParallelLinear(input_size=hidden_size,
-                                            output_size=intermediate_size +
-                                            self.conv_dim + self.num_heads,
-                                            bias=use_bias,
-                                            quant_config=quant_config)
+        self.in_proj = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=intermediate_size + self.conv_dim + self.num_heads,
+            bias=use_bias,
+            quant_config=quant_config,
+        )

         # - because in_proj is a concatenation of 3 weights, we
         # need to interleave them before sharding
@@ -305,7 +324,8 @@ def __init__(self,
         # - ditto for the otther two weights below
         delattr(self.conv1d.bias, "weight_loader")
         set_weight_attrs(
-            self.conv1d.bias, {
+            self.conv1d.bias,
+            {
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
@@ -316,18 +336,25 @@ def __init__(self,
                     self.tp_size,
                     tp_rank,
                 )
-            })
+            },
+        )

         delattr(self.conv1d.weight, "weight_loader")
         set_weight_attrs(
-            self.conv1d.weight, {
+            self.conv1d.weight,
+            {
                 "weight_loader":
-                mamba_v2_sharded_weight_loader([
-                    intermediate_settings,
-                    group_shard_settings,
-                    group_shard_settings,
-                ], self.tp_size, tp_rank)
-            })
+                mamba_v2_sharded_weight_loader(
+                    [
+                        intermediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                    ],
+                    self.tp_size,
+                    tp_rank,
+                )
+            },
+        )

         if quant_config is None:
             # - quant layers do not have a weight loader
@@ -345,8 +372,10 @@ def __init__(self,
                             head_setings,  # for dt
                         ],
                         self.tp_size,
-                        tp_rank)
-                })
+                        tp_rank,
+                    )
+                },
+            )

         # - these are TPed by heads to reduce the size of the
         # temporal shape
@@ -357,6 +386,7 @@ def __init__(self,
             ))
         self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
         self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
+        self.use_rms_norm = use_rms_norm

         set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
         a_weight_loader = composed_weight_loader(
@@ -365,25 +395,33 @@ def __init__(self,
         set_weight_attrs(self.dt_bias,
                          {"weight_loader": sharded_weight_loader(0)})

-        self.out_proj = RowParallelLinear(intermediate_size,
-                                          hidden_size,
-                                          bias=use_bias,
-                                          input_is_parallel=True,
-                                          quant_config=quant_config)
+        self.out_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=use_bias,
+            input_is_parallel=True,
+            quant_config=quant_config,
+        )

         self.norm = Mixer2RMSNormGated(intermediate_size,
                                        n_groups,
+                                       self.use_rms_norm,
                                        eps=rms_norm_eps)

-    def forward_native(self, hidden_states: torch.Tensor,
-                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+    ):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
     ):
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
@@ -401,6 +439,10 @@ def forward_cuda(

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
+
+        if mup_vector is not None:
+            projected_states = projected_states * mup_vector
+
         gate, hidden_states_B_C, dt = torch.split(
             projected_states,
             [
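For intuition about the new mup_vector argument in the hunk above: it is applied immediately after the input projection as a plain elementwise (broadcast) multiply, so a caller can inject a per-channel scaling, such as a muP-style multiplier, without touching in_proj itself. A small standalone sketch with made-up shapes and values:

```python
import torch

# Made-up shapes: 4 tokens, projection width 32 (in the real layer this is
# intermediate_size + conv_dim + num_heads on the current rank).
projected_states = torch.randn(4, 32)

# Optional per-channel scaling, broadcast over the token dimension; the
# values here are arbitrary and only illustrate the hook.
mup_vector = torch.full((1, 32), 0.5)

if mup_vector is not None:
    projected_states = projected_states * mup_vector

print(projected_states.shape)  # torch.Size([4, 32])
```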
@@ -561,6 +603,9 @@ def forward_cuda(
         hidden_states = torch.vstack(ssd_output_list)

         # 4. gated MLP
+        # GatedRMSNorm internally applies SiLU to the gate and multiplies it
+        # into the input before normalization, which differs from standard
+        # norm usage
         hidden_states = self.norm(hidden_states, gate)

         # 5. Final linear projection