
Commit 7394744

jeejeelee authored and py-andy-c committed
[Quantization] add BNB for MixtralForCausalLM (vllm-project#20893)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 3a1ed2e commit 7394744

File tree

7 files changed, +128 -20 lines

vllm/model_executor/model_loader/utils.py (6 additions, 1 deletion)

@@ -227,7 +227,12 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
+        "fp8",
+        "compressed-tensors",
+        "gptq_marlin",
+        "awq_marlin",
+        "quark",
+        "bitsandbytes",
     ]
 
     vllm_supported_archs = ModelRegistry.get_supported_archs()
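With "bitsandbytes" added to the quantized-Mixtral allow-list, a Mixtral checkpoint can be loaded with BNB quantization through the usual offline API. A minimal sketch follows; the checkpoint name is only an example and is not part of this commit, and depending on the vLLM version an explicit load_format may also be needed.

from vllm import LLM, SamplingParams

# Minimal sketch: load a Mixtral checkpoint with bitsandbytes (BNB)
# quantization enabled. Example checkpoint name, adjust to your model.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # example checkpoint
    quantization="bitsandbytes",  # in-flight BNB quantization
    dtype="bfloat16",
)

outputs = llm.generate(["Explain mixture-of-experts in one sentence."],
                       SamplingParams(max_tokens=48))
print(outputs[0].outputs[0].text)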

vllm/model_executor/models/granitemoe.py (102 additions, 3 deletions)

@@ -45,12 +45,14 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from . import mixtral
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, make_layers, maybe_prefix
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers,
+                    maybe_prefix)
 
 
 class GraniteMoeMoE(nn.Module):
@@ -307,6 +309,103 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def _load_weights(self,
+                      weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """
+        This function is copied from `MixtralModel.load_weights`, mainly to
+        decouple from mixtral, avoiding impact on support like BNB
+        quantization.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         new_weights = {}
@@ -339,7 +438,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return self._load_weights(new_weights.items())
 
 
 class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
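The copied loader relies on Python's for/else construct twice: the else branch runs only if the loop finishes without a break, which is how a weight name falls through from the stacked-parameter mapping to the expert mapping and finally to the plain-parameter path. Below is a stripped-down sketch of that dispatch shape; all names are hypothetical placeholders, not vLLM APIs.

# Stripped-down sketch of the nested for/else dispatch used in _load_weights.
# The else branch of a for loop runs only when the loop finished without
# `break`. Names here are hypothetical placeholders, not vLLM APIs.
def dispatch(name: str,
             stacked: list[tuple[str, str]],
             experts: list[tuple[str, str]]) -> str:
    for param_name, weight_name in stacked:
        if weight_name not in name:
            continue
        result = f"stacked loader -> {name.replace(weight_name, param_name)}"
        break
    else:
        for param_name, weight_name in experts:
            if weight_name not in name:
                continue
            result = f"expert loader -> {name.replace(weight_name, param_name)}"
            break
        else:
            result = f"default loader -> {name}"
    return result


stacked = [("qkv_proj", "q_proj")]
experts = [("experts.w13_weight", "experts.0.w1.")]
print(dispatch("layers.0.self_attn.q_proj.weight", stacked, experts))
print(dispatch("layers.0.block_sparse_moe.experts.0.w1.weight", stacked, experts))
print(dispatch("embed_tokens.weight", stacked, experts))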

vllm/model_executor/models/granitemoeshared.py (2 additions, 3 deletions)

@@ -27,8 +27,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from . import mixtral
-from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
+from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import AutoWeightsLoader, make_layers, maybe_prefix
 
@@ -242,7 +241,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                 new_weights[gate_name] = p
             else:
                 new_weights[n] = p
-        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+        return GraniteMoeModel._load_weights(self, new_weights.items())
 
 
 class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
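Note that the shared variant reuses the loader by calling GraniteMoeModel._load_weights as a plain function with its own instance passed as self, the same trick previously applied to mixtral.MixtralModel.load_weights; it only requires the caller to expose the attributes the loader touches. A minimal sketch of the pattern with hypothetical classes, not vLLM code:

# Minimal sketch: calling another class's method as an unbound function,
# passing a foreign instance as `self`. Hypothetical classes for illustration.
class BaseLoader:
    def _load(self, items: list[str]) -> set[str]:
        # Works for any object that provides a `seen` set attribute.
        self.seen.update(items)
        return self.seen


class SharedLoader:
    def __init__(self) -> None:
        self.seen: set[str] = set()

    def load(self, items: list[str]) -> set[str]:
        # Reuse BaseLoader's implementation without inheriting from it.
        return BaseLoader._load(self, items)


print(SharedLoader().load(["w1", "w2"]))  # {'w1', 'w2'}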

vllm/model_executor/models/mixtral.py (13 additions, 8 deletions)

@@ -317,6 +317,15 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -326,16 +335,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="w1",
-            ckpt_down_proj_name="w2",
-            ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
@@ -486,3 +488,6 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
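The mapping returned here is a list of (param_name, weight_name, expert_id, shard_id) tuples used to redirect per-expert checkpoint tensors into the fused FusedMoE parameters. The sketch below only approximates that shape; the exact parameter prefixes produced by FusedMoE.make_expert_params_mapping are vLLM-internal and may differ between versions.

# Illustrative sketch of an expert-params mapping of this shape. The fused
# parameter names and checkpoint prefixes below are assumptions, not the
# literal output of FusedMoE.make_expert_params_mapping.
def make_expert_params_mapping_sketch(
        gate: str, down: str, up: str,
        num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name, fused_name in (
                ("w1", gate, "experts.w13_weight"),  # gate proj -> fused w13
                ("w3", up, "experts.w13_weight"),    # up proj   -> fused w13
                ("w2", down, "experts.w2_weight")):  # down proj -> w2
            mapping.append((fused_name,
                            f"experts.{expert_id}.{ckpt_name}.",
                            expert_id, shard_id))
    return mapping


# e.g. a Mixtral-style config with 8 local experts:
for entry in make_expert_params_mapping_sketch("w1", "w2", "w3", 8)[:3]:
    print(entry)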

vllm/model_executor/models/olmoe.py (2 additions, 1 deletion)

@@ -352,6 +352,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -380,7 +381,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue

vllm/model_executor/models/qwen2_moe.py (2 additions, 1 deletion)

@@ -413,6 +413,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -442,7 +443,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in self.get_expert_mapping():
+                for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue

vllm/model_executor/models/qwen3_moe.py (1 addition, 3 deletions)

@@ -400,11 +400,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ".v_scale", "_v_scale", ".weight_scale",
             "_weight_scale", ".input_scale", "_input_scale")
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
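The olmoe, qwen2_moe, and qwen3_moe edits share one small refactor: the expert mapping is built once before the per-weight loop instead of being recomputed via self.get_expert_mapping() on every checkpoint tensor. A minimal sketch of the same hoisting, with hypothetical names standing in for the vLLM helpers:

# Minimal sketch of the hoisting refactor: compute a loop-invariant mapping
# once before iterating rather than on every iteration. Hypothetical names.
def make_mapping(num_experts: int) -> list[tuple[int, str]]:
    # Stands in for self.get_expert_mapping(); cost grows with num_experts.
    return [(e, f"experts.{e}.") for e in range(num_experts)]


def load(weights: list[str], num_experts: int) -> list[tuple[str, int]]:
    matched = []
    expert_params_mapping = make_mapping(num_experts)  # hoisted: built once
    for name in weights:
        for expert_id, prefix in expert_params_mapping:  # reused per weight
            if prefix in name:
                matched.append((name, expert_id))
                break
    return matched


print(load(["experts.0.w1.weight", "experts.3.w2.weight"], num_experts=8))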
