@@ -50,6 +50,7 @@
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -102,7 +103,7 @@ def graph_capture(device: torch.device):
     yield graph_capture_context


-class NPUModelRunner:
+class NPUModelRunner(LoRAModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -534,6 +535,10 @@ def _process_reqs(
         max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                        num_tokens)

+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
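Note on the hot-swap above: set_active_loras comes from LoRAModelRunnerMixin and runs once per scheduling step, so only the adapters referenced by the currently scheduled requests are active for the forward pass. A rough, self-contained sketch of the kind of bookkeeping involved; gather_active_loras and the "0 means no LoRA" convention are illustrative assumptions, not the mixin's actual implementation:

```python
# Illustrative only: gather the LoRA ids referenced by the scheduled requests
# and expand them per token, which is roughly the shape of input that
# per-batch LoRA activation needs. Names here are hypothetical.
import numpy as np

def gather_active_loras(request_lora_ids: list[int],
                        num_scheduled_tokens: np.ndarray):
    """Return the adapters needed this step and a per-token adapter-id map."""
    active = {lora_id for lora_id in request_lora_ids if lora_id != 0}
    token_lora_ids = np.repeat(np.array(request_lora_ids, dtype=np.int32),
                               num_scheduled_tokens)
    return active, token_lora_ids

# Example: three requests, two distinct adapters, request 1 runs without LoRA.
active, token_map = gather_active_loras([1, 0, 2], np.array([4, 2, 3]))
# active == {1, 2}; token_map has 4 + 2 + 3 = 9 entries, one per scheduled token.
```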
@@ -857,39 +862,55 @@ def _profile_multimodal(self) -> None:

     @torch.inference_mode()
     def _dummy_run(self, num_tokens: int) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None

-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]

-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=num_tokens,
-                        dtype=self.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
-
-        with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
-        return hidden_states
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None, self.vllm_config):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions.to(self.device),
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
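The new preamble in _dummy_run spreads num_tokens over at most max_num_seqs dummy requests, so the LoRA dummy run exercises a realistic request count; the split is even, with the remainder assigned to the last request. The splitting arithmetic in isolation, as a minimal standalone sketch:

```python
# Minimal standalone sketch of the token-splitting logic used above.
import numpy as np

def split_dummy_tokens(num_tokens: int, max_num_reqs: int) -> np.ndarray:
    """Even split across dummy requests; the remainder goes to the last one."""
    num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
    tokens = [num_tokens // num_reqs] * num_reqs
    tokens[-1] += num_tokens % num_reqs
    assert sum(tokens) == num_tokens
    return np.array(tokens, dtype=np.int32)

print(split_dummy_tokens(10, 4))  # [2 2 2 4]
print(split_dummy_tokens(3, 4))   # [1 1 1] -- fewer tokens than max_num_seqs
```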
@@ -938,7 +959,11 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.lora_config:
-                raise ValueError("LoRA model is not supported on NPU now.")
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))
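With load_lora_model wired into load_model, the LoRA path on NPU can be exercised through vLLM's public LoRA API. A hedged end-to-end sketch; the base model name and adapter path are placeholders, and constructor keywords may differ slightly between vLLM versions:

```python
# Sketch, not part of this PR: exercising the new LoRA path through vLLM's
# public API. Base model and adapter path below are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",  # placeholder base model
          enable_lora=True,
          max_loras=1,
          max_lora_rank=8)

outputs = llm.generate(
    ["Write a haiku about the ocean."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```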