Skip to content

Commit 8d6c4c4

Browse files
renaming
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 6943cc9 commit 8d6c4c4

File tree

2 files changed: +12 −12 lines changed

vllm/model_executor/models/phi3samba.py renamed to vllm/model_executor/models/phi4sambay.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
4949
return x1 * nn.functional.silu(x2)
5050

5151

52-
class SambaMLP(nn.Module):
52+
class SambaYMLP(nn.Module):
5353
"""Gated Linear Unit.
5454
5555
Reference:
@@ -78,7 +78,7 @@ def get_virtual_engine():
7878
forward_context: ForwardContext = get_forward_context()
7979
return forward_context.virtual_engine
8080

81-
class SambaAttention(nn.Module):
81+
class SambaYAttention(nn.Module):
8282
def __init__(self,
8383
config,
8484
layer_idx: Optional[int] = None,
@@ -391,7 +391,7 @@ def forward(
391391
return contextualized_states, yoco_key_values
392392

393393

394-
class SambaDecoderLayer(nn.Module):
394+
class SambaYDecoderLayer(nn.Module):
395395

396396
def __init__(self,
397397
config,
@@ -403,13 +403,13 @@ def __init__(self,
403403
self.config = config
404404
self.layer_idx = layer_idx
405405

406-
self.mlp = SambaMLP(config)
406+
self.mlp = SambaYMLP(config)
407407
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
408408

409409
self.yoco_mb = False
410410
self.yoco_kv = False
411411
self.yoco_cross = False
412-
assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for samba + yoco'
412+
assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco'
413413
if layer_idx >= config.num_hidden_layers//2:
414414
self.yoco_mb = True
415415
self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1))
@@ -420,7 +420,7 @@ def __init__(self,
420420
self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx,
421421
yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
422422
else:
423-
self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
423+
self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
424424
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
425425

426426
def forward(
@@ -469,7 +469,7 @@ def get_kv_cache(layer_name):
469469
kv_cache = self.kv_cache[forward_context.virtual_engine]
470470
return kv_cache
471471

472-
class SambaModel(nn.Module):
472+
class SambaYModel(nn.Module):
473473

474474
def __init__(
475475
self,
@@ -494,7 +494,7 @@ def __init__(
494494

495495
self.start_layer, self.end_layer, self.layers = make_layers(
496496
config.num_hidden_layers,
497-
lambda prefix: SambaDecoderLayer(config,
497+
lambda prefix: SambaYDecoderLayer(config,
498498
int(prefix.split('.')[-1]),
499499
cache_config,
500500
prefix=prefix),
@@ -590,7 +590,7 @@ def forward(
590590
return hidden_states
591591

592592

593-
class SambaForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
593+
class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
594594

595595
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
596596
config = vllm_config.model_config.hf_config
@@ -603,13 +603,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
603603
# Prefix caching is not supported since there are mamba layers in this
604604
# mode.
605605
assert not cache_config.enable_prefix_caching, \
606-
"Samba currently does not support prefix caching"
606+
"SambaY currently does not support prefix caching"
607607

608608
super().__init__()
609609
self.config = config
610610
self.model_config = vllm_config.model_config
611611
self.scheduler_config = scheduler_config
612-
self.model = SambaModel(
612+
self.model = SambaYModel(
613613
config,
614614
cache_config=cache_config,
615615
prefix=maybe_prefix(prefix, "model")

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@
110110
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
111111
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
112112
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
113-
"SambaForCausalLM": ("phi3samba", "SambaForCausalLM"),
113+
"Phi4MiniFlashForCausalLM": ("phi4sambay", "Phi4MiniFlashForCausalLM"),
114114
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
115115
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
116116
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),

0 commit comments

Comments (0)