Skip to content

Commit b982476

Browse files
committed
add/fix interface for cpu/cuda/rocm/xpu
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 7890945 commit b982476

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

vllm/platforms/cpu.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,30 @@ def default_v1(cls, model_config) -> bool:
275275
arch = cls.get_cpu_architecture()
276276
return (cls.supports_v1(model_config) and arch
277277
in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
278+
279+
@classmethod
def empty_cache(cls):
    """Intentional no-op: the CPU backend keeps no device-side memory cache to flush."""
282+
283+
@classmethod
def reset_peak_memory_stats(cls):
    """Intentional no-op: no peak-memory counters are tracked for the CPU backend."""
286+
287+
@classmethod
def mem_get_info(cls):
    """Best-effort analogue of ``torch.cuda.mem_get_info`` for the CPU backend.

    Returns:
        A ``(free, total)`` tuple of host memory sizes in bytes, or ``None``
        when the platform does not expose the POSIX sysconf counters
        (e.g. Windows, or libc builds without ``SC_AVPHYS_PAGES``) — the
        previous unimplemented behavior.
    """
    import os

    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        total = os.sysconf("SC_PHYS_PAGES") * page_size
        free = os.sysconf("SC_AVPHYS_PAGES") * page_size
    except (AttributeError, OSError, ValueError):
        # FIXME: no portable fallback yet; keep the unimplemented sentinel.
        return None
    if page_size <= 0 or total <= 0 or free < 0:
        # sysconf returns -1 for "supported but unavailable" — treat as unknown.
        return None
    return (free, total)
291+
292+
@classmethod
def memory_stats(cls):
    """Allocator statistics are not tracked on CPU; always returns ``None``."""
    # FIXME: impl — provide real statistics once the CPU backend tracks them.
    return None
296+
297+
@classmethod
def memory_reserved(cls):
    """Reserved-memory accounting is unavailable on CPU; always returns ``None``."""
    # FIXME: impl — report reserved bytes once the CPU backend tracks them.
    return None
301+
302+
@classmethod
def synchronize(cls):
    """Wait for outstanding work on the CPU "device" (kept for interface parity)."""
    torch.cpu.synchronize()

vllm/platforms/cuda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def device_count(cls) -> int:
425425
return cuda_device_count_stateless()
426426

427427
@classmethod
def empty_cache(cls):
    """Release unused cached blocks held by the CUDA caching allocator."""
    torch.cuda.empty_cache()
430430

431431
@classmethod
@@ -446,7 +446,7 @@ def memory_reserved(cls):
446446

447447
@classmethod
def synchronize(cls):
    """Block until all queued work on the current CUDA device completes.

    Returns None per the interface-wide convention (torch.cuda.synchronize
    itself returns None, so dropping the ``return`` changed nothing).
    """
    torch.cuda.synchronize()
450450

451451

452452
# NVML utils

vllm/platforms/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def stateless_init_device_torch_dist_pg(
549549
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
550550

551551
@classmethod
def empty_cache(cls):
    """Release cached device memory; platform subclasses must override.

    Raises:
        NotImplementedError: always, on the base platform interface.
    """
    raise NotImplementedError
554554

555555
@classmethod
@@ -570,7 +570,7 @@ def memory_reserved(cls):
570570

571571
@classmethod
def synchronize(cls):
    """Wait for all outstanding device work; platform subclasses must override.

    Raises:
        NotImplementedError: always, on the base platform interface.
    """
    raise NotImplementedError
574574

575575

576576
class UnspecifiedPlatform(Platform):

vllm/platforms/rocm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,3 +463,27 @@ def stateless_init_device_torch_dist_pg(
463463
@classmethod
def device_count(cls) -> int:
    """Number of visible devices (reuses the CUDA/HIP stateless counter)."""
    return cuda_device_count_stateless()
466+
467+
@classmethod
def empty_cache(cls):
    """Release unused cached blocks via the torch.cuda API (HIP on ROCm builds)."""
    torch.cuda.empty_cache()
470+
471+
@classmethod
def reset_peak_memory_stats(cls):
    """Reset the caching allocator's peak-memory counters for the current device."""
    torch.cuda.reset_peak_memory_stats()
474+
475+
@classmethod
def mem_get_info(cls):
    """Return (free, total) device memory in bytes for the current device."""
    return torch.cuda.mem_get_info()
478+
479+
@classmethod
def memory_stats(cls):
    """Return the caching allocator's statistics dictionary for the current device."""
    return torch.cuda.memory_stats()
482+
483+
@classmethod
def memory_reserved(cls):
    """Bytes currently reserved by the caching allocator on the current device."""
    return torch.cuda.memory_reserved()
486+
487+
@classmethod
def synchronize(cls):
    """Block until all queued work on the current device has finished."""
    torch.cuda.synchronize()

vllm/platforms/xpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def device_count(cls) -> int:
196196
return torch.xpu.device_count()
197197

198198
@classmethod
def empty_cache(cls):
    """Release unused cached memory held by the XPU caching allocator."""
    torch.xpu.empty_cache()
201201

202202
@classmethod
@@ -230,4 +230,4 @@ def memory_reserved(cls):
230230

231231
@classmethod
def synchronize(cls):
    """Block until all queued work on the current XPU device completes.

    Returns None per the interface-wide convention (torch.xpu.synchronize
    itself returns None, so dropping the ``return`` changed nothing).
    """
    torch.xpu.synchronize()

0 commit comments

Comments
 (0)