@@ -51,7 +51,8 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -286,6 +287,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        self.quant_config = quant_config

         self.config = config
         lora_vocab = (lora_config.lora_extra_vocab_size *
@@ -339,6 +341,62 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        return loaded_params
+

 class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
     packed_modules_mapping = {
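Moving `load_weights` onto the inner `CohereModel` keeps the stacked-parameter remapping next to the modules it feeds, so the causal-LM wrapper below can delegate to it. For readers unfamiliar with the convention: each entry in `stacked_params_mapping` routes a per-projection checkpoint tensor (`q_proj`, `k_proj`, `v_proj`, `gate_proj`, `up_proj`) into a slice of vLLM's fused `qkv_proj` or `gate_up_proj` parameter via that parameter's `weight_loader`. A minimal, self-contained sketch of the idea, with toy shapes and a hypothetical `load_fused_shard` helper standing in for the real shard-aware `weight_loader`:

```python
# Minimal sketch of the stacked-parameter idea (toy sizes; the
# load_fused_shard helper is hypothetical and stands in for the
# shard-aware weight_loader on vLLM's fused parameters).
import torch

HIDDEN = 8
PROJ = 8  # rows per projection; real models use num_heads * head_dim

# Fused QKV parameter laid out as [q | k | v] along dim 0.
fused_qkv = torch.empty(3 * PROJ, HIDDEN)
shard_offsets = {"q": 0, "k": PROJ, "v": 2 * PROJ}


def load_fused_shard(fused: torch.Tensor, shard: torch.Tensor,
                     shard_id: str) -> None:
    """Copy one projection's checkpoint tensor into its slice of `fused`."""
    start = shard_offsets[shard_id]
    fused[start:start + shard.shape[0]].copy_(shard)


# A fake checkpoint with per-projection names.
checkpoint = {
    "layers.0.self_attn.q_proj.weight": torch.randn(PROJ, HIDDEN),
    "layers.0.self_attn.k_proj.weight": torch.randn(PROJ, HIDDEN),
    "layers.0.self_attn.v_proj.weight": torch.randn(PROJ, HIDDEN),
}

# Same routing as stacked_params_mapping in the diff above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

for name, weight in checkpoint.items():
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            load_fused_shard(fused_qkv, weight, shard_id)
            break

assert torch.equal(fused_qkv[:PROJ],
                   checkpoint["layers.0.self_attn.q_proj.weight"])
```

In the real method the slice offsets belong to the fused parameter's own `weight_loader`, which is why the loop above only has to pass `shard_id`.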
@@ -408,65 +466,6 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-
-            # Skip loading rotary embeddings since vLLM has its own
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"])
+        return loader.load_weights(weights)
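
`AutoWeightsLoader` replaces the hand-written loop: it dispatches checkpoint tensors to the submodules that own them (reusing the `CohereModel.load_weights` added above) and drops names matching `skip_prefixes`, here the `lm_head` weight (tied to the input embedding) and the unused `rotary_emb.inv_freq` buffers that the deleted code skipped by hand. A simplified, hypothetical sketch of that delegation pattern, not vLLM's actual implementation:

```python
# Hypothetical, simplified stand-in for AutoWeightsLoader (illustration only).
# Assumes the wrapped module exposes an inner `.model` with its own
# `load_weights`, as CohereForCausalLM does above.
from collections.abc import Iterable

import torch


class PrefixSkippingLoader:

    def __init__(self, top_module, skip_prefixes: list[str]):
        self.top_module = top_module
        self.skip_prefixes = skip_prefixes

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Drop anything matching a skip entry (e.g. "lm_head",
        # "rotary_emb.inv_freq"); the real class matches more carefully.
        def keep(name: str) -> bool:
            return not any(skip in name for skip in self.skip_prefixes)

        # Strip the wrapper-level "model." prefix and let the inner model's
        # load_weights handle fused-parameter remapping and kv-scales.
        filtered = ((name.removeprefix("model."), weight)
                    for name, weight in weights if keep(name))
        return self.top_module.model.load_weights(filtered)
```

The wrapper's new three-line `load_weights` at the end of the diff is exactly this shape: build the loader with the prefixes to skip, then hand over the weight iterator.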