@@ -49,7 +49,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
         return x1 * nn.functional.silu(x2)
 
 
-class SambaMLP(nn.Module):
+class SambaYMLP(nn.Module):
     """Gated Linear Unit.
 
     Reference:
@@ -78,7 +78,7 @@ def get_virtual_engine():
     forward_context: ForwardContext = get_forward_context()
     return forward_context.virtual_engine
 
-class SambaAttention(nn.Module):
+class SambaYAttention(nn.Module):
     def __init__(self,
                  config,
                  layer_idx: Optional[int] = None,
@@ -391,7 +391,7 @@ def forward(
         return contextualized_states, yoco_key_values
 
 
-class SambaDecoderLayer(nn.Module):
+class SambaYDecoderLayer(nn.Module):
 
     def __init__(self,
                  config,
@@ -403,13 +403,13 @@ def __init__(self,
         self.config = config
         self.layer_idx = layer_idx
 
-        self.mlp = SambaMLP(config)
+        self.mlp = SambaYMLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.yoco_mb = False
         self.yoco_kv = False
         self.yoco_cross = False
-        assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for samba + yoco'
+        assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco'
         if layer_idx >= config.num_hidden_layers // 2:
             self.yoco_mb = True
             self.yoco_kv = (layer_idx >= (config.num_hidden_layers // 2 + 1))
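For readers skimming the diff: the flags set in this hunk implement the YOCO-style split, in which only the lower half of the stack keeps its own KV cache and the upper half reuses it. Below is a minimal standalone sketch of that flag assignment, assuming a hypothetical num_hidden_layers of 8 purely for illustration (the yoco_cross assignment and the attention-vs-Mamba choice in the next hunk are not reproduced here):

num_hidden_layers = 8  # hypothetical; must be divisible by 4 per the assert above
assert num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco'

for layer_idx in range(num_hidden_layers):
    # Mirrors the hunk: yoco_mb marks the upper half of the stack,
    # yoco_kv everything in the upper half except its first layer.
    yoco_mb = layer_idx >= num_hidden_layers // 2
    yoco_kv = yoco_mb and layer_idx >= num_hidden_layers // 2 + 1
    print(f"layer {layer_idx}: yoco_mb={yoco_mb}, yoco_kv={yoco_kv}")

# Expected output: layers 0-3 -> both False, layer 4 -> yoco_mb only,
# layers 5-7 -> both True.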
@@ -420,7 +420,7 @@ def __init__(self,
             self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx,
                                   yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
         else:
-            self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
+            self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(
@@ -469,7 +469,7 @@ def get_kv_cache(layer_name):
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return kv_cache
 
-class SambaModel(nn.Module):
+class SambaYModel(nn.Module):
 
     def __init__(
         self,
@@ -494,7 +494,7 @@ def __init__(
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: SambaDecoderLayer(config,
+            lambda prefix: SambaYDecoderLayer(config,
                                               int(prefix.split('.')[-1]),
                                               cache_config,
                                               prefix=prefix),
@@ -590,7 +590,7 @@ def forward(
         return hidden_states
 
 
-class SambaForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
+class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -603,13 +603,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Prefix caching is not supported since there are mamba layers in this
         # mode.
         assert not cache_config.enable_prefix_caching, \
-            "Samba currently does not support prefix caching"
+            "SambaY currently does not support prefix caching"
 
         super().__init__()
         self.config = config
         self.model_config = vllm_config.model_config
         self.scheduler_config = scheduler_config
-        self.model = SambaModel(
+        self.model = SambaYModel(
             config,
             cache_config=cache_config,
             prefix=maybe_prefix(prefix, "model")
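With the causal-LM entry point renamed to Phi4MiniFlashForCausalLM, serving should follow the usual vLLM flow. The snippet below is a hypothetical usage sketch, not part of this diff; it assumes the class is registered for the microsoft/Phi-4-mini-flash-reasoning checkpoint (the checkpoint name is not confirmed by the hunks above) and leaves prefix caching disabled to satisfy the assert above.

from vllm import LLM, SamplingParams

# Hypothetical checkpoint name; the diff itself only renames the vLLM-side classes.
llm = LLM(model="microsoft/Phi-4-mini-flash-reasoning", trust_remote_code=True)

outputs = llm.generate(["Summarize what YOCO KV-cache sharing buys at long context."],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)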