Commit bd81a61

renames
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 8d6c4c4 commit bd81a61

2 files changed: +5, -15 lines changed

vllm/model_executor/models/phi4sambay.py renamed to vllm/model_executor/models/phi4flash.py

Lines changed: 4 additions & 14 deletions
@@ -45,7 +45,6 @@
 class SwiGLUActivation(nn.Module):
 
     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
-        # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}")
         return x1 * nn.functional.silu(x2)
 
 
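For context on the surrounding code (the hunk above only drops a debug print): SwiGLUActivation gates one projection with the SiLU of the other. A minimal, self-contained sketch assuming only PyTorch; the tensor shapes below are illustrative and not taken from the model:

import torch
import torch.nn as nn

class SwiGLUActivation(nn.Module):
    # Gated activation: multiply x1 elementwise by silu(x2).
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1 * nn.functional.silu(x2)

# Illustrative usage with made-up shapes (batch=2, hidden=8).
x1, x2 = torch.randn(2, 8), torch.randn(2, 8)
out = SwiGLUActivation()(x1, x2)  # same shape as the inputs: (2, 8)
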
@@ -175,7 +174,7 @@ def forward(
         return self.out_proj(attn_output)
 
 
-class Phi3Mamba(nn.Module):
+class Phi4Mamba(nn.Module):
     def __init__(
         self,
         d_model,
@@ -250,15 +249,6 @@ def __init__(
             params_dtype=dtype,
         )
 
-        # # S4D real initialization
-        # A = repeat(
-        #     torch.arange(1, self.d_state + 1, dtype=torch.float32),
-        #     "n -> d n",
-        #     d=self.d_inner,
-        # ).contiguous()
-        # A_log = torch.log(A)  # Keep A_log in fp32
-        # self.A_log = nn.Parameter(A_log)
-
         # # D "skip" parameter
         # self.D = nn.Parameter(torch.ones(self.d_inner))  # Keep in fp32
         self.A = nn.Parameter(
@@ -417,7 +407,7 @@ def __init__(self,
         self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0
         if self.use_mamba:
             factory_kwargs = {"dtype": None}
-            self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx,
+            self.attn = Phi4Mamba(config.hidden_size, layer_idx=layer_idx,
                                   yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
         else:
             self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
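The condition in this hunk decides which layers get a Mamba block versus SambaY attention: every mb_per_layer-th layer (starting at layer 0) uses Phi4Mamba, the rest use attention. A quick illustration of just that predicate, with an mb_per_layer value made up for the example:

mb_per_layer = 2  # illustrative value, not read from any real config
layer_kinds = [
    "mamba" if mb_per_layer > 0 and i % mb_per_layer == 0 else "attention"
    for i in range(6)
]
# layer_kinds == ['mamba', 'attention', 'mamba', 'attention', 'mamba', 'attention']
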
@@ -590,7 +580,7 @@ def forward(
         return hidden_states
 
 
-class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
+class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -603,7 +593,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Prefix caching is not supported since there are mamba layers in this
         # mode.
         assert not cache_config.enable_prefix_caching, \
-            "SambaY currently does not support prefix caching"
+            "Phi4flash currently does not support prefix caching"
 
         super().__init__()
         self.config = config

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
     "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
-    "Phi4MiniFlashForCausalLM": ("phi4sambay", "Phi4MiniFlashForCausalLM"),
+    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
     "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
