@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 # Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -29,7 +30,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -63,19 +64,22 @@ def __init__(
         config: NemotronHConfig,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = MergedColumnParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_sizes=[config.intermediate_size],
+            output_size=config.intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = ReLUSquaredActivation()
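Note on the swap above: in vLLM, ColumnParallelLinear takes a single output_size, while MergedColumnParallelLinear fuses several projections and expects an output_sizes list; this MLP has only one up-projection, so the plain variant is sufficient. As a minimal sketch, the forward pass would look roughly as follows (the method body is not part of this diff; it assumes the usual vLLM convention that parallel linear layers return an (output, output_bias) tuple):

    def forward(self, x):
        # Up-project, apply ReLU^2, then down-project back to hidden_size.
        x, _ = self.up_proj(x)    # hidden_size -> intermediate_size
        x = self.act_fn(x)        # ReLUSquaredActivation
        x, _ = self.down_proj(x)  # intermediate_size -> hidden_size
        return x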
@@ -99,9 +103,12 @@ def __init__(
         super().__init__()
         self.config = config

-        self.mixer = NemotronHMLP(config,
-                                  quant_config=quant_config,
-                                  bias=config.mlp_bias)
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+        )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -207,12 +214,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.attn = Attention(
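Note: the prefix strings added here give each parallel linear layer a fully qualified module name, which (as far as I know) is what vLLM quantization configs match against when selecting per-layer quantization behavior. A rough, hypothetical illustration of how the names compose for this attention block (the actual root prefix is defined elsewhere in the file):

    # Hypothetical example: suppose the decoder layer passes prefix="backbone.layers.3.mixer"
    mixer_prefix = "backbone.layers.3.mixer"  # illustrative value only, not from this diff
    qkv_name = f"{mixer_prefix}.qkv_proj"     # -> "backbone.layers.3.mixer.qkv_proj"
    o_name = f"{mixer_prefix}.o_proj"         # -> "backbone.layers.3.mixer.o_proj"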
@@ -253,7 +262,7 @@ def __init__(
             layer_idx,
             cache_config,
             quant_config,
-            prefix,
+            prefix=f"{prefix}.mixer",
         )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
-        "gate_up_proj": ["up_proj", "down_proj"]
     }

     # LoRA specific attributes
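Note: with up_proj now a plain ColumnParallelLinear, there is no merged gate/up weight left to pack, so the "gate_up_proj" entry (which mapped up_proj and down_proj) is dropped. Assuming the surrounding, unshown lines keep their current shape, the class attribute is left with only the QKV entry, roughly:

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }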