Commit 7661e92

[Model] Optimize nemotron_h implementation (#19249)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent: f168b85


vllm/model_executor/models/nemotron_h.py

Lines changed: 16 additions & 8 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 # Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -29,7 +30,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -63,19 +64,22 @@ def __init__(
         config: NemotronHConfig,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = MergedColumnParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_sizes=[config.intermediate_size],
+            output_size=config.intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = ReLUSquaredActivation()
 
@@ -99,9 +103,12 @@ def __init__(
         super().__init__()
         self.config = config
 
-        self.mixer = NemotronHMLP(config,
-                                  quant_config=quant_config,
-                                  bias=config.mlp_bias)
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+        )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -207,12 +214,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.attn = Attention(
@@ -253,7 +262,7 @@ def __init__(
             layer_idx,
             cache_config,
             quant_config,
-            prefix,
+            prefix=f"{prefix}.mixer",
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
-        "gate_up_proj": ["up_proj", "down_proj"]
     }
 
     # LoRA specific attributes
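
For reference, here is a minimal sketch of what NemotronHMLP looks like after this change, assembled from the hunks above. The __init__ body mirrors the diff; the imports are gathered here only to make the snippet self-contained, and the forward pass is an assumption based on the usual up_proj -> ReLU^2 -> down_proj pattern rather than something shown in this commit.

from typing import Optional

import torch
from torch import nn

from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig


class NemotronHMLP(nn.Module):
    """ReLU^2 MLP: a single up projection, no gate branch."""

    def __init__(
        self,
        config,  # NemotronHConfig in the real file
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Only one output projection is needed, so a plain ColumnParallelLinear
        # suffices; MergedColumnParallelLinear exists to fuse several
        # column-parallel projections (e.g. gate + up) into one GEMM.
        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=config.intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed forward pass: vLLM parallel linear layers return
        # (output, output_bias); the bias part is not used here.
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x

Because this MLP has no gate projection to fuse with up_proj, the old "gate_up_proj": ["up_proj", "down_proj"] entry in packed_modules_mapping had nothing valid to pack, which is why the last hunk drops it. The new prefix arguments propagate a fully qualified layer name (something like backbone.layers.0.mixer.up_proj, depending on the module hierarchy) down to each parallel linear layer, which per-layer quantization configs and weight loading generally key on.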
