Skip to content

Commit 7484e1f

Browse files
authored
Add cache to cuda get_device_capability (#19436)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent a2142f0 commit 7484e1f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vllm/platforms/cuda.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
import os
88
from datetime import timedelta
9-
from functools import wraps
9+
from functools import cache, wraps
1010
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
1111

1212
import torch
@@ -389,6 +389,7 @@ def stateless_init_device_torch_dist_pg(
389389
class NvmlCudaPlatform(CudaPlatformBase):
390390

391391
@classmethod
392+
@cache
392393
@with_nvml_context
393394
def get_device_capability(cls,
394395
device_id: int = 0
@@ -486,6 +487,7 @@ def log_warnings(cls):
486487
class NonNvmlCudaPlatform(CudaPlatformBase):
487488

488489
@classmethod
490+
@cache
489491
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
490492
major, minor = torch.cuda.get_device_capability(device_id)
491493
return DeviceCapability(major=major, minor=minor)

0 commit comments

Comments (0)