
Commit c8c8ca6

py-andy-c authored and minpeter committed

[Model] use AutoWeightsLoader for commandr (vllm-project#19399)

Signed-off-by: py-andy-c <pychen1017@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>

1 parent c9425b7 commit c8c8ca6

1 file changed

vllm/model_executor/models/commandr.py (+62 −63 lines: 62 additions, 63 deletions)
@@ -51,7 +51,8 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -286,6 +287,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        self.quant_config = quant_config

         self.config = config
         lora_vocab = (lora_config.lora_extra_vocab_size *
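Note: the stored self.quant_config is what the new CohereModel.load_weights (added in the next hunk) consults to recognize kv-cache quantization scales. A minimal sketch of that check, assuming only that get_cache_scale(name) returns the remapped scale-parameter name or None; the helper below is hypothetical and not part of the commit:

    def _is_kv_cache_scale(quant_config, name: str) -> bool:
        # Hypothetical illustration helper, not in the diff: a weight is a
        # kv-cache scale iff get_cache_scale maps its checkpoint name to a
        # parameter name.
        return (quant_config is not None
                and quant_config.get_cache_scale(name) is not None)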
@@ -339,6 +341,62 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
     packed_modules_mapping = {
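Note: stacked_params_mapping above folds separate checkpoint tensors (q_proj, k_proj, v_proj, gate_proj, up_proj) into vLLM's fused qkv_proj and gate_up_proj parameters before loading. A standalone sketch of just the rename step, using an illustrative checkpoint key that is not taken from the commit:

    # Illustrative remap: a q_proj checkpoint tensor targets the fused
    # qkv_proj parameter and is loaded as its "q" shard.
    name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical key
    for param_name, shard_name, shard_id in [("qkv_proj", "q_proj", "q")]:
        if shard_name in name:
            name = name.replace(shard_name, param_name)
            # name is now "model.layers.0.self_attn.qkv_proj.weight";
            # shard_id "q" tells the fused weight_loader which slice to fill.
            break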
@@ -408,65 +466,6 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-
-            # Skip loading rotary embeddings since vLLM has its own
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if (self.quant_config is not None and
-                    (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"])
+        return loader.load_weights(weights)
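Note: after this change CohereForCausalLM.load_weights only filters and delegates. AutoWeightsLoader skips weights matching skip_prefixes (the tied lm_head and vLLM's own rotary inv_freq buffers) and routes everything else to the wrapped module, so the stacked-shard and kv-scale handling lives once, in CohereModel.load_weights above. A rough sketch of that delegation pattern, assuming nothing about vLLM's real AutoWeightsLoader beyond the call shape visible in the diff:

    class SketchAutoWeightsLoader:
        # Simplified stand-in for illustration only; vLLM's actual class
        # recurses over submodules rather than assuming a .model attribute.
        def __init__(self, module, skip_prefixes=None):
            self.module = module              # here: CohereForCausalLM
            self.skip_prefixes = skip_prefixes or []

        def load_weights(self, weights):
            kept = ((name, w) for name, w in weights
                    if not any(p in name for p in self.skip_prefixes))
            # Hand the surviving weights to the inner model, which owns the
            # detailed shard and scale logic.
            return self.module.model.load_weights(kept)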
