
Commit a4c2331

[xpu]feat: support multi-lora on xpu (#20616)
Signed-off-by: yan <yan.ma@intel.com>
1 parent b942c09 commit a4c2331

5 files changed (+28, −4 lines)

vllm/lora/ops/triton_ops/lora_expand_op.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@

 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op


@@ -283,6 +284,7 @@ def _lora_expand_fake(
     op_func=_lora_expand,
     mutates_args=["output_tensor"],
     fake_impl=_lora_expand_fake,
+    dispatch_key=current_platform.dispatch_key,
 )
 lora_expand = torch.ops.vllm.lora_expand
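
Both LoRA kernel files register their Triton ops through vLLM's direct_register_custom_op helper; the added dispatch_key argument ties the registration to the current platform's dispatch key instead of the CUDA default. Below is a minimal sketch of the same registration pattern with a toy op; the toy names are illustrative and not part of this commit, and the helper is assumed to fall back to the CUDA key when dispatch_key is omitted.

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def _toy_scale(x: torch.Tensor, output_tensor: torch.Tensor,
               scale: float) -> None:
    # Real implementation: runs on whatever device backs the inputs.
    output_tensor.copy_(x * scale)


def _toy_scale_fake(x: torch.Tensor, output_tensor: torch.Tensor,
                    scale: float) -> None:
    # Fake (meta) implementation used during tracing; does no real work.
    return


direct_register_custom_op(
    op_name="toy_scale",
    op_func=_toy_scale,
    mutates_args=["output_tensor"],
    fake_impl=_toy_scale_fake,
    # The change in this commit: without an explicit key the op is registered
    # for CUDA only; current_platform.dispatch_key is expected to be "XPU" on
    # Intel GPUs, so the same Triton path also dispatches for XPU tensors.
    dispatch_key=current_platform.dispatch_key,
)
toy_scale = torch.ops.vllm.toy_scale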

vllm/lora/ops/triton_ops/lora_shrink_op.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@

 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op


@@ -237,6 +238,7 @@ def _lora_shrink_fake(
     op_func=_lora_shrink,
     mutates_args=["output_tensor"],
     fake_impl=_lora_shrink_fake,
+    dispatch_key=current_platform.dispatch_key,
 )
 lora_shrink = torch.ops.vllm.lora_shrink
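
lora_shrink gets the identical dispatch_key treatment. As a rough illustration of what that key resolves to (assuming the platform classes expose PyTorch dispatch-key names, which is how the attribute is consumed above):

from vllm.platforms import current_platform

# Expected to print "XPU" on an Intel GPU build and "CUDA" on NVIDIA; the
# custom ops above are registered under exactly this backend key.
print(current_platform.dispatch_key)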

vllm/lora/ops/triton_ops/utils.py

Lines changed: 9 additions & 3 deletions
@@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
         lora_strides_d1.append(lora_a_weight.stride(1))
         lora_strides_d2.append(lora_a_weight.stride(2))
     if len(lora_a_weights) > 1:
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
     else:
         lora_ptr_tensor = lora_a_weights[0]

@@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,

     if len(lora_weights) > 1:
         # note these are device tensors
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
-        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
+        slice_start_tensor = torch.tensor(slice_offset_lst,
+                                          device=device,
+                                          dtype=torch.uint64)
     else:
         slice_start_tensor = slice_offset_lst[0]
         lora_ptr_tensor = lora_b_weight[0]
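
The functional change in this file is pinning the pointer tensors to torch.uint64. A small sketch of the pattern, assuming (as in the surrounding helpers) that tensor_ptrs holds raw data_ptr() addresses of the per-slice LoRA weights; without the explicit dtype, torch.tensor infers signed int64, which presumably cannot hold the full address range seen on XPU.

import torch

# Illustrative stand-ins for the per-slice LoRA weight tensors.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
lora_a_weights = [torch.randn(4, 16, 64, device=device) for _ in range(3)]

# Raw device addresses as Python ints (the assumption noted above).
tensor_ptrs = [w.data_ptr() for w in lora_a_weights]

# Explicit uint64 keeps the full unsigned address range; plain
# torch.tensor(tensor_ptrs) would infer int64 and can overflow when an
# address has the high bit set.
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)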

vllm/model_executor/model_loader/tensorizer.py

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, PlaceholderModule

 if TYPE_CHECKING:

@@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
             **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                 stream,
                 dtype=tensorizer_config.dtype,
-                device=torch.device("cuda", torch.cuda.current_device()),
+                device=f'xpu:{torch.xpu.current_device()}'
+                if current_platform.is_xpu() else
+                f'cuda:{torch.cuda.current_device()}',
                 **tensorizer_args.deserialization_kwargs) as deserializer:
         deserializer.load_into_module(model)
         end = time.perf_counter()
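
The tensorizer change only swaps the deserialization target device when running on XPU. A minimal sketch of that selection using the same calls as the diff; the TensorDeserializer call itself is omitted.

import torch
from vllm.platforms import current_platform

# Pick the deserialization target: an explicit "xpu:N" string on Intel GPUs,
# otherwise the current CUDA device, mirroring the diff above.
if current_platform.is_xpu():
    device = f"xpu:{torch.xpu.current_device()}"
else:
    device = f"cuda:{torch.cuda.current_device()}"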

vllm/platforms/xpu.py

Lines changed: 11 additions & 0 deletions
@@ -58,6 +58,10 @@ def get_device_capability(
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)

+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)

@@ -78,6 +82,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64

+        # FIXME: Temporarily forcing eager mode
+        # remove after t.compile support stabilizes.
+        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+                and not vllm_config.model_config.enforce_eager):
+            from vllm.config import CompilationLevel
+            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
+
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
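
get_punica_wrapper returns a dotted path rather than a class, so the XPU platform reuses the GPU Punica wrapper that drives the Triton LoRA ops registered above, while the second hunk forces eager mode (CompilationLevel.NO_COMPILATION) on XPU under V1 until torch.compile support stabilizes. Below is a rough sketch of how such a qualified name could be resolved at runtime; the resolver shown here is illustrative and the helper vLLM actually uses may differ.

import importlib


def resolve_qualname(qualname: str):
    # Split "pkg.module.ClassName" into module path and attribute, then import.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# Path returned by XPUPlatform.get_punica_wrapper() in this commit.
punica_cls = resolve_qualname(
    "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU")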
