
Commit 322f009

congcongchen123 authored and Chen-zexi committed
[Model] New model support for microsoft/Phi-4-mini-flash-reasoning (vllm-project#20702)
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 0ee2b2e commit 322f009

22 files changed: +1869 −41 lines changed

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 25 additions & 24 deletions

@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
     constexpr bool kIsVariableB = true;
     constexpr bool kIsVariableC = true;
-    constexpr bool kHasZ = true;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
-            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-            dim3 grid(params.batch, params.dim / kNRows);
-            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-            if (kSmemSize >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            }
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
+            BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
+                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+                constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+                dim3 grid(params.batch, params.dim / kNRows);
+                auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+                if (kSmemSize >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+                }
+                kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
         });
     });
 }

@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
 
     at::Tensor z, out_z;
     const bool has_z = z_.has_value();
-    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
-    z = z_.value();
-    TORCH_CHECK(z.scalar_type() == input_type);
-    TORCH_CHECK(z.is_cuda());
-    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
-    if (varlen){
-        CHECK_SHAPE(z, dim, seqlen);
-    } else {
-        CHECK_SHAPE(z, batch_size, dim, seqlen);
+    if (has_z) {
+        z = z_.value();
+        TORCH_CHECK(z.scalar_type() == input_type);
+        TORCH_CHECK(z.is_cuda());
+        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
+        if (varlen){
+            CHECK_SHAPE(z, dim, seqlen);
+        } else {
+            CHECK_SHAPE(z, batch_size, dim, seqlen);
+        }
+
+        out_z = z;
     }
 
-    out_z = z;
-
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
     TORCH_CHECK(ssm_states.scalar_type() == input_type);

@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
     });
 }
-

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

@@ -374,6 +374,7 @@ Specified using `--task generate`.
 | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
 | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
 | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
 | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
 | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |

tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -248,6 +248,10 @@ def check_available_online(
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                             trust_remote_code=True,
                                             v0_only=True),
+    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
+                                            trust_remote_code=True,
+                                            v0_only=True,
+                                            max_model_len=10240),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",

tests/models/test_initialization.py

Lines changed: 3 additions & 0 deletions

@@ -103,6 +103,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
             _initialize_kv_caches_v1), monkeypatch.context() as m):
         if model_info.v0_only:
             m.setenv("VLLM_USE_V1", "0")
+        if model_arch == "Phi4FlashForCausalLM":
+            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
+            m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,

tests/test_utils.py

Lines changed: 25 additions & 0 deletions

@@ -458,6 +458,31 @@ def test_bind_kv_cache():
     assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
     assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
 
+def test_bind_kv_cache_kv_sharing():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    shared_kv_cache_layers = {
+        'layers.2.self_attn': 'layers.1.self_attn',
+        'layers.3.self_attn': 'layers.0.self_attn'
+    }
+    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
+
 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
 
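The new test exercises bind_kv_cache with a shared_kv_cache_layers mapping (reusing-layer name -> target-layer name), so a sharing layer ends up bound to its target's cache tensor. Below is a simplified, self-contained sketch of just that mapping semantics as the assertions above encode it; the function name and binding logic are an illustration, not vLLM's actual bind_kv_cache implementation.

# Toy illustration of the sharing semantics asserted above; vLLM's real
# bind_kv_cache also handles layer-index parsing, multiple virtual engines, etc.
def bind_kv_cache_sketch(ctx, kv_caches, shared_kv_cache_layers=None):
    shared_kv_cache_layers = shared_kv_cache_layers or {}
    # Layers that do not reuse another layer's cache each own one slot,
    # assigned in insertion order of ctx.
    owners = [name for name in ctx if name not in shared_kv_cache_layers]
    slot = {name: i for i, name in enumerate(owners)}
    for name, attn in ctx.items():
        # A sharing layer is bound to the tensor of its target layer.
        target = shared_kv_cache_layers.get(name, name)
        attn.kv_cache = [kv_caches[0][slot[target]]]

With the ctx, kv_cache, and shared_kv_cache_layers values from the test above, this sketch reproduces all four assertions.
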
vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 1 deletion

@@ -308,7 +308,8 @@ def __init__(
         kv_sharing_target_layer_name: Optional[str] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "BLOCK_SPARSE_FLASH_ATTN Backend.")
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for blocksparse flash attention.")
