[1/N] Refactor platform API to reduce torch.cuda call #20751
base: main

Changes from all commits: 4299610, 24f90da, a5ec772, 5a3b6fa, d7ca0ed, 81a1c1f
CPU platform:

```diff
@@ -275,3 +275,30 @@ def default_v1(cls, model_config) -> bool:
         arch = cls.get_cpu_architecture()
         return (cls.supports_v1(model_config) and arch
                 in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
+
+    @classmethod
+    def empty_cache(cls):
+        pass
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        pass
+
+    @classmethod
+    def mem_get_info(cls):
+        # FIXME: impl
+        return None
+
+    @classmethod
+    def memory_stats(cls):
+        # FIXME: impl
+        return None
+
+    @classmethod
+    def memory_reserved(cls):
+        # FIXME: impl
+        return None
```
Comment on lines +287 to +300:

The stubs for `mem_get_info`, `memory_stats`, and `memory_reserved` currently return `None`, which will break any caller that expects numeric values. These methods should return sensible default values for the CPU platform to prevent crashes. For example:
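(The reviewer's suggested change was not preserved in this capture; the sketch below is only a plausible reconstruction, assuming `psutil` is available on the CPU platform and that reporting host RAM as "device" memory is acceptable.)

```python
import psutil  # assumption: psutil is available in the CPU build

class CpuPlatform:  # class name assumed for illustration
    @classmethod
    def mem_get_info(cls):
        # Report host RAM so callers get a (free, total) tuple,
        # mirroring the torch.cuda.mem_get_info() contract.
        vm = psutil.virtual_memory()
        return vm.available, vm.total

    @classmethod
    def memory_stats(cls):
        # No caching-allocator statistics exist on CPU.
        return {}

    @classmethod
    def memory_reserved(cls):
        # Nothing is reserved by a device allocator on CPU.
        return 0
```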
```diff
+
+    @classmethod
+    def synchronize(cls):
+        pass
```
XPU platform:

```diff
@@ -194,3 +194,40 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def empty_cache(cls):
+        torch.xpu.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.xpu.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        if cls.is_data_center_gpu():
+            return torch.xpu.mem_get_info()
+        else:
+            # We provide this branch because `torch.xpu.mem_get_info()`
+            # doesn't return the correct free_gpu_memory on Intel client
+            # GPUs, so we have to calculate/estimate it.
+            _, total_gpu_memory = torch.xpu.mem_get_info()
+            # FIXME: memory_allocated() doesn't count non-torch
+            # allocations, and we don't have any API to get them, so we
+            # assume 128 MiB.
+            used_memory = torch.xpu.memory_allocated()
+            non_torch_allocations = 128 * 1024 * 1024
+            free_gpu_memory = total_gpu_memory - (used_memory +
+                                                  non_torch_allocations)
```
Comment on lines +218 to +220:

This hardcoded value of 128 MiB for non-torch allocations is only a rough estimate.
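(The suggested change itself was lost in this capture; one way to avoid baking in the constant is sketched below. The `VLLM_XPU_NON_TORCH_MEMORY_BYTES` environment variable is hypothetical, not part of this PR.)

```python
import os

# Fallback estimate from the PR; overridable without a code change.
_DEFAULT_NON_TORCH_BYTES = 128 * 1024 * 1024

def non_torch_allocation_bytes() -> int:
    # Hypothetical env var so deployments can tune the estimate.
    return int(os.environ.get("VLLM_XPU_NON_TORCH_MEMORY_BYTES",
                              _DEFAULT_NON_TORCH_BYTES))
```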
```diff
+            return free_gpu_memory, total_gpu_memory
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.xpu.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.xpu.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        torch.xpu.synchronize()
```
GPU model runner:

```diff
@@ -38,6 +38,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
@@ -345,12 +346,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
     def _init_device_properties(self) -> None:
         """Initialize attributes from torch.cuda.get_device_properties
         """
-        self.device_properties = torch.cuda.get_device_properties(self.device)
-        self.num_sms = self.device_properties.multi_processor_count
+        if current_platform.is_cuda():
+            self.device_properties = torch.cuda.get_device_properties(
+                self.device)
+            self.num_sms = self.device_properties.multi_processor_count
+        else:
+            self.num_sms = None
```
Comment on lines 348 to +354, quoting the new platform-aware branch:

```python
if current_platform.is_cuda():
    self.device_properties = torch.cuda.get_device_properties(
        self.device)
    self.num_sms = self.device_properties.multi_processor_count
else:
    self.num_sms = None
```
```diff

     # Note: used for model runner override.
     def _sync_device(self) -> None:
-        torch.cuda.synchronize()
+        current_platform.synchronize()

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
```
```diff
@@ -2270,7 +2275,7 @@ def capture_model(self) -> None:
         compilation_counter.num_gpu_runner_capture_triggers += 1

         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        start_free_gpu_memory = current_platform.mem_get_info()[0]

         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
@@ -2296,7 +2301,7 @@ def capture_model(self) -> None:
                              skip_eplb=True)

         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        end_free_gpu_memory = current_platform.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
```
Here and below: the same can be done with the standard PyTorch API, available starting from torch 2.6. Are there actual benefits to defining a similar device abstraction at the vLLM level? Using the standard PyTorch API would help keep the vLLM code base leaner. See https://docs.pytorch.org/docs/stable/generated/torch.accelerator.synchronize.html#torch.accelerator.synchronize
I feel we'd better implement the base-class `Platform.synchronize()` method using `torch.accelerator.synchronize()` and leave it to platforms to implement their own in case there are any tricks, like pytorch/pytorch#155668.
This might be a good compromise. Which torch version does vLLM target across device backends? Note that `torch.accelerator` is available from 2.6. If vLLM needs to support a wider torch range, that can be a clear reason to abstract this API at the vLLM level. Also, if you see any missing APIs in `torch.accelerator`, please give feedback - we are willing to take care of that at the PyTorch level.
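A minimal sketch of the compromise proposed in this thread — a `torch.accelerator`-based default in the `Platform` base class with a per-platform override — assuming torch >= 2.6 and the classmethod-style API used in this PR:

```python
import torch

class Platform:
    @classmethod
    def synchronize(cls):
        # Generic default: torch.accelerator works across backends
        # (CUDA, XPU, ...) starting with torch 2.6.
        torch.accelerator.synchronize()

class XPUPlatform(Platform):
    @classmethod
    def synchronize(cls):
        # Backend-specific override for cases where the generic API
        # is insufficient (e.g. pytorch/pytorch#155668).
        torch.xpu.synchronize()
```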