class SwiGLUActivation(nn.Module):

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
-        # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}")
        return x1 * nn.functional.silu(x2)

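For reference, SwiGLU gates one projection of the input with the SiLU of another. Below is a minimal sketch of how this activation is typically wired into a feed-forward block; the SwiGLUMLPSketch class, the gate_proj/up_proj/down_proj names, and the d_ff size are illustrative assumptions, not taken from this file.

import torch
import torch.nn as nn

class SwiGLUMLPSketch(nn.Module):
    # Illustrative only: a feed-forward block built around the SwiGLUActivation above.
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)    # value branch (x1)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)  # gate branch (x2)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act = SwiGLUActivation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x1 * silu(x2): the up projection is modulated by the SiLU-gated projection
        return self.down_proj(self.act(self.up_proj(x), self.gate_proj(x)))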
@@ -175,7 +174,7 @@ def forward(
        return self.out_proj(attn_output)


-class Phi3Mamba(nn.Module):
+class Phi4Mamba(nn.Module):
    def __init__(
        self,
        d_model,
@@ -250,15 +249,6 @@ def __init__(
            params_dtype=dtype,
        )

-        # # S4D real initialization
-        # A = repeat(
-        #     torch.arange(1, self.d_state + 1, dtype=torch.float32),
-        #     "n -> d n",
-        #     d=self.d_inner,
-        # ).contiguous()
-        # A_log = torch.log(A)  # Keep A_log in fp32
-        # self.A_log = nn.Parameter(A_log)
-
        # # D "skip" parameter
        # self.D = nn.Parameter(torch.ones(self.d_inner))  # Keep in fp32
        self.A = nn.Parameter(
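For context, the commented-out block removed above is the standard S4D real initialization: A[d, n] = n for n = 1..d_state, broadcast across all d_inner channels and stored in log space. A standalone sketch of that computation, assuming placeholder sizes for d_state and d_inner:

import torch
from einops import repeat

d_state, d_inner = 16, 8192  # placeholder sizes; the real ones come from the module config

# Every channel gets the same row [1, 2, ..., d_state]
A = repeat(
    torch.arange(1, d_state + 1, dtype=torch.float32),
    "n -> d n",
    d=d_inner,
).contiguous()
A_log = torch.log(A)  # kept in fp32, stored in log space
A_log_param = torch.nn.Parameter(A_log)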
@@ -417,7 +407,7 @@ def __init__(self,
        self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0
        if self.use_mamba:
            factory_kwargs = {"dtype": None}
-            self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx,
+            self.attn = Phi4Mamba(config.hidden_size, layer_idx=layer_idx,
                                  yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
        else:
            self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
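The mb_per_layer check above decides the hybrid layout: every mb_per_layer-th layer gets a Phi4Mamba mixer and the rest get SambaYAttention. A small illustrative sketch of the resulting layout, assuming made-up values for mb_per_layer and num_layers:

mb_per_layer = 2   # placeholder; the real value is config.mb_per_layer
num_layers = 8     # placeholder

for layer_idx in range(num_layers):
    use_mamba = mb_per_layer > 0 and layer_idx % mb_per_layer == 0
    print(layer_idx, "Phi4Mamba" if use_mamba else "SambaYAttention")
# 0 Phi4Mamba, 1 SambaYAttention, 2 Phi4Mamba, 3 SambaYAttention, ...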
@@ -590,7 +580,7 @@ def forward(
        return hidden_states


-class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
+class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
@@ -603,7 +593,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Prefix caching is not supported since there are mamba layers in this
        # mode.
        assert not cache_config.enable_prefix_caching, \
-            "SambaY currently does not support prefix caching"
+            "Phi4flash currently does not support prefix caching"

        super().__init__()
        self.config = config
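Because of the assert above, the engine has to be constructed with prefix caching disabled for this model. A hedged usage sketch; the model path is a placeholder, not a real checkpoint id:

from vllm import LLM

llm = LLM(
    model="path/to/phi4-flash-checkpoint",  # placeholder path
    enable_prefix_caching=False,            # required by the assert in Phi4FlashForCausalLM
)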