Diff summary: 5 files changed, 57 insertions(+), 6 deletions(-).
def default_v1(cls, model_config) -> bool:
    """Return True when this CPU platform should default to the V1 engine.

    Requires both that the model itself is V1-capable (``supports_v1``) and
    that the host CPU architecture is one of x86, PowerPC, or ARM.
    """
    # NOTE(review): presumably decorated @classmethod like its siblings —
    # the decorator line is outside this diff hunk; confirm in the full file.
    arch = cls.get_cpu_architecture()
    return (cls.supports_v1(model_config) and arch
            in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
278
+
279
@classmethod
def empty_cache(cls):
    """No-op: the CPU backend has no device-side allocator cache to release."""
    return None
282
+
283
@classmethod
def reset_peak_memory_stats(cls):
    """No-op: peak-memory tracking is not maintained for the CPU backend."""
    return None
286
+
287
@classmethod
def mem_get_info(cls):
    """Placeholder: device memory info is not yet implemented for CPU."""
    # FIXME: impl
    # NOTE(review): the CUDA counterpart returns a (free, total) tuple;
    # callers must tolerate None until this is implemented.
    return None
291
+
292
@classmethod
def memory_stats(cls):
    """Placeholder: allocator statistics are not yet implemented for CPU."""
    # FIXME: impl
    # NOTE(review): CUDA counterpart returns a dict of allocator stats;
    # callers must tolerate None until this is implemented.
    return None
296
+
297
@classmethod
def memory_reserved(cls):
    """Placeholder: reserved-memory reporting is not yet implemented for CPU."""
    # FIXME: impl
    # NOTE(review): CUDA counterpart returns an int byte count; callers
    # must tolerate None until this is implemented.
    return None
301
+
302
@classmethod
def synchronize(cls):
    """Synchronize the CPU device by delegating to ``torch.cpu.synchronize()``."""
    torch.cpu.synchronize()
Original file line number Diff line number Diff line change @@ -425,7 +425,7 @@ def device_count(cls) -> int:
425
425
return cuda_device_count_stateless ()
426
426
427
427
@classmethod
def empty_cache(cls):
    """Release unoccupied cached memory held by the CUDA caching allocator."""
    # Signature normalized in this diff: trailing comma after `cls` removed.
    torch.cuda.empty_cache()
430
430
431
431
@classmethod
@@ -446,7 +446,7 @@ def memory_reserved(cls):
446
446
447
447
@classmethod
def synchronize(cls):
    """Block until all queued work on the current CUDA device has completed."""
    # The diff drops the redundant `return` — torch.cuda.synchronize()
    # has no meaningful return value, so the method now returns None.
    torch.cuda.synchronize()
450
450
451
451
452
452
# NVML utils
Original file line number Diff line number Diff line change @@ -549,7 +549,7 @@ def stateless_init_device_torch_dist_pg(
549
549
raise RuntimeError (f"Unsupported torch distributed backend: { backend } " )
550
550
551
551
@classmethod
def empty_cache(cls):
    """Not supported on this platform."""
    # Signature normalized in this diff: trailing comma after `cls` removed.
    raise NotImplementedError
554
554
555
555
@classmethod
@@ -570,7 +570,7 @@ def memory_reserved(cls):
570
570
571
571
@classmethod
def synchronize(cls):
    """Not supported on this platform."""
    # The diff replaces the previous torch.accelerator.synchronize() call
    # with an explicit NotImplementedError, matching empty_cache above.
    raise NotImplementedError
574
574
575
575
576
576
class UnspecifiedPlatform (Platform ):
Original file line number Diff line number Diff line change @@ -463,3 +463,27 @@ def stateless_init_device_torch_dist_pg(
463
463
@classmethod
464
464
def device_count (cls ) -> int :
465
465
return cuda_device_count_stateless ()
466
+
467
@classmethod
def empty_cache(cls):
    """Release unoccupied cached memory held by the caching allocator.

    Uses the ``torch.cuda`` namespace, which this platform shares.
    """
    torch.cuda.empty_cache()
470
+
471
@classmethod
def reset_peak_memory_stats(cls):
    """Reset the peak-memory counters tracked by the device allocator."""
    torch.cuda.reset_peak_memory_stats()
474
+
475
@classmethod
def mem_get_info(cls):
    """Return the (free, total) device memory, in bytes, for the current device."""
    return torch.cuda.mem_get_info()
478
+
479
@classmethod
def memory_stats(cls):
    """Return the allocator statistics dict for the current device."""
    return torch.cuda.memory_stats()
482
+
483
@classmethod
def memory_reserved(cls):
    """Return the bytes currently reserved by the caching allocator."""
    return torch.cuda.memory_reserved()
486
+
487
@classmethod
def synchronize(cls):
    """Block until all queued work on the current device has completed."""
    torch.cuda.synchronize()
Original file line number Diff line number Diff line change @@ -196,7 +196,7 @@ def device_count(cls) -> int:
196
196
return torch .xpu .device_count ()
197
197
198
198
@classmethod
def empty_cache(cls):
    """Release unoccupied cached memory held by the XPU caching allocator."""
    # Signature normalized in this diff: trailing comma after `cls` removed.
    torch.xpu.empty_cache()
201
201
202
202
@classmethod
@@ -230,4 +230,4 @@ def memory_reserved(cls):
230
230
231
231
@classmethod
def synchronize(cls):
    """Block until all queued work on the current XPU device has completed."""
    # The diff drops the redundant `return` — torch.xpu.synchronize()
    # has no meaningful return value, so the method now returns None.
    torch.xpu.synchronize()
End of diff. 0 commit comments.