
Commit eca1869

dhiaEddineRhaiem, younesbelkada, ilyasch2, and JingweiZuo authored
[MODEL] FalconH1 (vllm-project#18406)
Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
1 parent 61acfc4 commit eca1869

File tree

5 files changed: +798 -59 lines changed


docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
@@ -392,6 +392,11 @@ Specified using `--task generate`.
   * `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
   * ✅︎
   * ✅︎
+- * `FalconH1ForCausalLM`
+  * Falcon-H1
+  * `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc.
+  * ✅︎
+  * ✅︎
 - * `GemmaForCausalLM`
   * Gemma
   * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.
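
For reference, a minimal usage sketch (editorial, not part of this diff) showing how one of the Falcon-H1 checkpoints listed above can be loaded through vLLM's offline `LLM` API; the checkpoint choice and sampling settings are illustrative:

```python
# Minimal sketch (illustrative, not part of this commit): loading a Falcon-H1
# checkpoint through vLLM's offline generation API once the architecture is
# registered.
from vllm import LLM, SamplingParams

llm = LLM(model="tiiuae/Falcon-H1-1.5B-Instruct")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Falcon-H1 is a hybrid attention/Mamba model that"], params)
print(outputs[0].outputs[0].text)
```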

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -147,6 +147,9 @@ def check_available_online(
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
+    "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
+                                           is_available_online=False,
+                                           min_transformers_version="4.52.2"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 104 additions & 59 deletions
@@ -34,7 +34,11 @@
 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):
 
-    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
+    def __init__(self,
+                 full_hidden_size: int,
+                 full_n_groups: int,
+                 use_rms_norm: bool = True,
+                 eps: float = 1e-6):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -44,11 +48,17 @@ def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
         self.n_groups = full_hidden_size // self.group_size

         self.variance_epsilon = eps
-        self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
-        set_weight_attrs(self.weight,
-                         {"weight_loader": sharded_weight_loader(0)})
-        assert self.full_hidden_size % self.tp_size== 0,\
-            "Tensor parallel world size must divide hidden size."
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight,
+                             {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (self.full_hidden_size % self.tp_size == 0
+                ), "Tensor parallel world size must divide hidden size."

     def forward_native(
         self,
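
Editorial aside: the `else` branch above registers the weight as `None` so it never appears in the module's `state_dict`, which keeps checkpoint loading free of missing/unexpected-key errors. A toy, self-contained illustration (not from this diff):

```python
# Toy demonstration of register_parameter("weight", None): the parameter is
# simply absent from state_dict(), so checkpoint loading stays consistent.
import torch
import torch.nn as nn

class ToyGatedNorm(nn.Module):
    def __init__(self, use_rms_norm: bool):
        super().__init__()
        if use_rms_norm:
            self.weight = nn.Parameter(torch.ones(8))
        else:
            self.register_parameter("weight", None)

print(ToyGatedNorm(True).state_dict().keys())   # odict_keys(['weight'])
print(ToyGatedNorm(False).state_dict().keys())  # odict_keys([])
```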
@@ -66,6 +76,8 @@ def forward_native(
         # the input and then redundantly compute the RMSNorm.
         input_dtype = x.dtype
         x = x * nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x

         if self.n_groups == 1:
             if self.tp_size > 1:
@@ -74,7 +86,7 @@ def forward_native(
                 global_sums = tensor_model_parallel_all_reduce(local_sums)
                 # Calculate the variance
                 count = self.tp_size * x.shape[-1]
-                variance = (global_sums / count)
+                variance = global_sums / count

             else:
                 variance = x.pow(2).mean(-1, keepdim=True)
@@ -106,6 +118,9 @@ def forward_cuda(
         gate: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

+        if not self.use_rms_norm:
+            return x * nn.functional.silu(gate.to(torch.float32))
+
         if self.tp_size > 1 or self.n_groups != 1:
             return self.forward_native(x, gate)

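Editorial aside: taken together, the two early returns reduce the layer to the reference semantics below (single group, no tensor parallelism); with `use_rms_norm=False` only the SiLU gating remains. A sketch mirroring the `forward_native` math:

```python
# Reference semantics of Mixer2RMSNormGated (single group, no TP): SiLU-gate
# first, then optionally RMS-normalize and scale by the learned weight.
import torch
import torch.nn.functional as F

def gated_rms_norm_ref(x, gate, weight=None, eps=1e-6, use_rms_norm=True):
    input_dtype = x.dtype
    x = x * F.silu(gate.to(torch.float32))
    if not use_rms_norm:
        return x  # gate-only path enabled by this commit
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight * x).to(input_dtype) if weight is not None else x.to(input_dtype)
```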
@@ -124,7 +139,7 @@ def forward_cuda(


 def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for 
+    """Compute the increase in group numbers to account for
     replication in order to accompany the head shards."""

     # in the case ngoups % tp_size == 0, this will be zero
@@ -182,13 +197,15 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # seem to handle slices well.
             # https://github.com/python/mypy/issues/2410
             param.data[
-                boundary:(boundary + take),  # type: ignore[misc]
-                ...] = loaded_weight[loaded_start_idx:(  # type: ignore[misc]
-                    loaded_start_idx + take)]  # type: ignore[misc]
+                boundary:(boundary + take),
+                ...  # type: ignore[misc]
+            ] = loaded_weight[loaded_start_idx:(loaded_start_idx +
+                                                take)  # type: ignore[misc]
+                              ]  # type: ignore[misc]

             # move indexing boundaries
             boundary += shard_size
-            loaded_boundary += (full_dim - extra)
+            loaded_boundary += full_dim - extra

     return loader

@@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
     **selective** state spaces)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 ssm_state_size: int,
-                 conv_kernel_size: int,
-                 intermediate_size: int,
-                 use_conv_bias: bool,
-                 use_bias: bool,
-                 n_groups: int = 1,
-                 num_heads: int = 128,
-                 head_dim: int = 64,
-                 rms_norm_eps: float = 1e-5,
-                 activation="silu",
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()

         # For TP, the sharding plan is as follows:
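
Editorial sketch of the widened constructor: a caller can now disable the internal gated RMSNorm via the new `use_rms_norm` flag. The dimensions below are placeholders rather than Falcon-H1's actual configuration, and constructing the layer assumes vLLM's tensor-parallel state is already initialized:

```python
# Hypothetical instantiation (placeholder sizes) showing the new keyword;
# assumes vLLM's distributed/tensor-parallel state has been initialized.
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2

mixer = MambaMixer2(
    hidden_size=4096,
    ssm_state_size=128,
    conv_kernel_size=4,
    intermediate_size=8192,
    use_conv_bias=True,
    use_bias=False,
    n_groups=1,
    num_heads=128,
    head_dim=64,
    rms_norm_eps=1e-5,
    activation="silu",
    use_rms_norm=False,  # new in this commit: keep SiLU gating, skip RMSNorm
)
```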
@@ -238,17 +258,16 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()

-        assert num_heads % self.tp_size == 0, \
-            "Tensor parallel world size must divide num heads."
+        assert (num_heads % self.tp_size == 0
+                ), "Tensor parallel world size must divide num heads."

-        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
-            (
-                "If tensor parallel world size does not divide num_heads, "
-                "then num_groups must equal 1."
-            )
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
+            "If tensor parallel world size does not divide num_heads, "
+            "then num_groups must equal 1.")

-        assert self.tp_size == 1 or quant_config is None, \
-            "Tensor parallel currently not supported for quantized models."
+        assert (
+            self.tp_size == 1 or quant_config is None
+        ), "Tensor parallel currently not supported for quantized models."

         self.ssm_state_size = ssm_state_size
         self.activation = activation
@@ -265,8 +284,7 @@ def __init__(self,
             self.n_groups = n_groups + extra_groups_for_head_shards(
                 n_groups, self.tp_size)

-        self.conv_dim = (intermediate_size +
-                         2 * self.n_groups * ssm_state_size)
+        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
             output_size=self.conv_dim,
@@ -279,11 +297,12 @@ def __init__(self,
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        self.in_proj = ColumnParallelLinear(input_size=hidden_size,
-                                            output_size=intermediate_size +
-                                            self.conv_dim + self.num_heads,
-                                            bias=use_bias,
-                                            quant_config=quant_config)
+        self.in_proj = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=intermediate_size + self.conv_dim + self.num_heads,
+            bias=use_bias,
+            quant_config=quant_config,
+        )

         # - because in_proj is a concatenation of 3 weights, we
         #   need to interleave them before sharding
@@ -305,7 +324,8 @@ def __init__(self,
         # - ditto for the otther two weights below
         delattr(self.conv1d.bias, "weight_loader")
         set_weight_attrs(
-            self.conv1d.bias, {
+            self.conv1d.bias,
+            {
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
@@ -316,18 +336,25 @@ def __init__(self,
                     self.tp_size,
                     tp_rank,
                 )
-            })
+            },
+        )

         delattr(self.conv1d.weight, "weight_loader")
         set_weight_attrs(
-            self.conv1d.weight, {
+            self.conv1d.weight,
+            {
                 "weight_loader":
-                mamba_v2_sharded_weight_loader([
-                    intermediate_settings,
-                    group_shard_settings,
-                    group_shard_settings,
-                ], self.tp_size, tp_rank)
-            })
+                mamba_v2_sharded_weight_loader(
+                    [
+                        intermediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                    ],
+                    self.tp_size,
+                    tp_rank,
+                )
+            },
+        )

         if quant_config is None:
             # - quant layers do not have a weight loader
@@ -345,8 +372,10 @@ def __init__(self,
                             head_setings,  # for dt
                         ],
                         self.tp_size,
-                        tp_rank)
-                })
+                        tp_rank,
+                    )
+                },
+            )

         # - these are TPed by heads to reduce the size of the
         #   temporal shape
@@ -357,6 +386,7 @@ def __init__(self,
             ))
         self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
         self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
+        self.use_rms_norm = use_rms_norm

         set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
         a_weight_loader = composed_weight_loader(
@@ -365,25 +395,33 @@ def __init__(self,
         set_weight_attrs(self.dt_bias,
                          {"weight_loader": sharded_weight_loader(0)})

-        self.out_proj = RowParallelLinear(intermediate_size,
-                                          hidden_size,
-                                          bias=use_bias,
-                                          input_is_parallel=True,
-                                          quant_config=quant_config)
+        self.out_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=use_bias,
+            input_is_parallel=True,
+            quant_config=quant_config,
+        )

         self.norm = Mixer2RMSNormGated(intermediate_size,
                                        n_groups,
+                                       self.use_rms_norm,
                                        eps=rms_norm_eps)

-    def forward_native(self, hidden_states: torch.Tensor,
-                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+    ):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
     ):
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
@@ -401,6 +439,10 @@ def forward_cuda(

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
+
+        if mup_vector is not None:
+            projected_states = projected_states * mup_vector
+
         gate, hidden_states_B_C, dt = torch.split(
             projected_states,
             [
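
Editorial aside on the new `mup_vector` hook: it lets a caller rescale the fused `in_proj` output channel-wise (for example, μP-style multipliers) before the split into gate, conv input, and dt. A shape-level sketch with illustrative sizes:

```python
# Illustration of the mup_vector multiply above: a (1, proj_dim) multiplier
# broadcasts over every token of the fused projection. Sizes are assumptions.
import torch

num_tokens, proj_dim = 10, 4352       # proj_dim ~ intermediate + conv_dim + num_heads
projected_states = torch.randn(num_tokens, proj_dim)
mup_vector = torch.ones(1, proj_dim)  # identity scaling by default
mup_vector[:, :2048] *= 0.5           # e.g. rescale one channel block
projected_states = projected_states * mup_vector
```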
@@ -561,6 +603,9 @@ def forward_cuda(
         hidden_states = torch.vstack(ssd_output_list)

         # 4. gated MLP
+        # GatedRMSNorm internally applying SiLU to the gate
+        # SiLU is applied internally before normalization, unlike standard
+        # norm usage
         hidden_states = self.norm(hidden_states, gate)

         # 5. Final linear projection

0 commit comments
