@@ -2553,10 +2553,10 @@ def __post_init__(self):

     def measure(self):
         # we measure the torch peak memory usage via allocated_bytes,
-        # rather than `torch.cuda.memory_reserved()`.
-        # After `torch.cuda.reset_peak_memory_stats()`,
-        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
-        # when we call `current_platform.empty_cache()` or OOM happens.
+        # rather than `current_platform.memory_reserved()`.
+        # After `current_platform.reset_peak_memory_stats()`,
+        # `current_platform.memory_reserved()` will keep growing, and only
+        # shrink when we call `current_platform.empty_cache()` or OOM happens.
         from vllm.platforms import current_platform

         self.torch_peak = current_platform.memory_stats().get(
@@ -2565,7 +2565,7 @@ def measure(self):
         self.free_memory, self.total_memory = current_platform.mem_get_info()
         self.cuda_memory = self.total_memory - self.free_memory

-        # torch.cuda.memory_reserved() is how many bytes
+        # current_platform.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
         # this is used to measure the non-torch memory usage
         self.torch_memory = current_platform.memory_reserved()
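
For context, here is a minimal sketch of the same accounting written directly against torch.cuda (on CUDA devices, current_platform forwards to these calls). The helper name snapshot_memory is illustrative and not part of the PR; it assumes a CUDA-capable PyTorch build.

import torch

def snapshot_memory() -> tuple[int, int, int]:
    # Hypothetical helper mirroring the diff's logic, not vLLM code.
    # Peak bytes actually allocated for tensors since the last
    # torch.cuda.reset_peak_memory_stats(); unlike memory_reserved(),
    # this peak does not keep growing as the caching allocator retains
    # freed blocks.
    torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

    # Device-wide free/total bytes as reported by the driver.
    free_memory, total_memory = torch.cuda.mem_get_info()
    used_memory = total_memory - free_memory

    # Bytes PyTorch obtained from CUDA (via cudaMalloc etc.); subtracting
    # them from the device usage isolates non-torch consumers such as
    # NCCL buffers or other libraries sharing the GPU.
    torch_memory = torch.cuda.memory_reserved()
    non_torch_memory = used_memory - torch_memory
    return torch_peak, torch_memory, non_torch_memory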