Skip to content

[Model] New model support for microsoft/Phi-4-mini-flash-reasoning #20702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
Expand Down Expand Up @@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

at::Tensor z, out_z;
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}

out_z = z;
}

out_z = z;

// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
Expand Down Expand Up @@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
}

1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ Specified using `--task generate`.
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
Expand Down
4 changes: 4 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def check_available_online(
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
Expand Down
3 changes: 3 additions & 0 deletions tests/models/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,31 @@ def test_bind_kv_cache():
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]

def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention

ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
shared_kv_cache_layers = {
'layers.2.self_attn': 'layers.1.self_attn',
'layers.3.self_attn': 'layers.0.self_attn'
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]

def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

Expand Down
3 changes: 2 additions & 1 deletion vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def __init__(
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"BLOCK_SPARSE_FLASH_ATTN Backend.")
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
Expand Down
Loading