Skip to content

Commit e064d0e

Browse files
DimiChatzipavlisDimiChatzipavlis
DimiChatzipavlis
authored and
DimiChatzipavlis
committed
Fix openvino case
1 parent c4c0e5e commit e064d0e

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

keras/src/callbacks/memory_usage_callback.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from keras.src.api_export import keras_export
66
from keras.src.callbacks.callback import Callback
77

8-
# Attempt to import psutil for memory monitoring
8+
# Attempt to import psutil for CPU memory monitoring
99
try:
1010
import psutil
1111
except ImportError:
@@ -14,17 +14,19 @@
1414

1515
@keras_export("keras.callbacks.MemoryUsageCallback")
1616
class MemoryUsageCallback(Callback):
17-
"""Monitors and logs memory usage (CPU + optional GPU/TPU) during training.
17+
"""Monitors and logs memory usage
18+
(CPU + optional GPU/TPU/OpenVINO) during training.
1819
1920
This callback measures:
2021
2122
- **CPU**: via psutil.Process().memory_info().rss
22-
- **GPU/TPU**: via backend-specific APIs
23-
(TensorFlow, PyTorch, JAX, OpenVINO)
23+
- **GPU/TPU**: via backendspecific APIs
24+
(TensorFlow, PyTorch, JAX)
2425
2526
Logs are printed to stdout at the start/end of each epoch and,
26-
if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
27-
is provided, scalars are also written via `tf.summary` (TensorBoard).
27+
if `log_every_batch=True`, after every batch.
28+
If `tensorboard_log_dir` is provided, scalars are also written
29+
via `tf.summary` (TensorBoard).
2830
2931
Args:
3032
monitor_gpu (bool): If True, attempt to measure accelerator memory.
@@ -34,22 +36,13 @@ class MemoryUsageCallback(Callback):
3436
3537
Raises:
3638
ImportError: If `psutil` is not installed (required for CPU logging).
37-
38-
Example:
39-
```python
40-
from keras.callbacks import MemoryUsageCallback
41-
# ...
42-
cb = MemoryUsageCallback(
43-
monitor_gpu=True,
44-
log_every_batch=False,
45-
tensorboard_log_dir="./logs/memory"
46-
)
47-
model.fit(X, y, callbacks=[cb])
48-
```
4939
"""
5040

5141
def __init__(
52-
self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None
42+
self,
43+
monitor_gpu=True,
44+
log_every_batch=False,
45+
tensorboard_log_dir=None,
5346
):
5447
super().__init__()
5548
if psutil is None:
@@ -109,8 +102,7 @@ def _log_step(self, label, step):
109102
msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
110103
if gpu_mb is not None:
111104
msg += f"; GPU Memory: {gpu_mb:.2f} MB"
112-
# newline + flush ensures clean, immediate output
113-
print("\n" + msg, flush=True)
105+
print(msg)
114106

115107
if self._writer:
116108
import tensorflow as tf # noqa: E501
@@ -119,7 +111,7 @@ def _log_step(self, label, step):
119111
tf.summary.scalar("Memory/CPU_MB", cpu_mb)
120112
if gpu_mb is not None:
121113
tf.summary.scalar("Memory/GPU_MB", gpu_mb)
122-
self._writer.flush()
114+
# flush happens inside writer
123115

124116
def _get_cpu_memory(self):
125117
return self._proc.memory_info().rss / (1024**2)
@@ -162,12 +154,21 @@ def _get_gpu_memory(self):
162154
total += stats.get("bytes_in_use", 0)
163155
return total / (1024**2)
164156

157+
elif backend_name == "openvino":
158+
# OpenVINO provides no memory-stats API:
159+
if not hasattr(self, "_warn_openvino"):
160+
warnings.warn(
161+
" OpenVINO does not expose memory stats; "
162+
"GPU monitoring disabled.",
163+
RuntimeWarning,
164+
)
165+
self._warn_openvino = True
166+
return None
167+
165168
else:
166-
# OpenVINO and other unknown backends: warn once
167169
if not hasattr(self, "_warn_backend"):
168170
warnings.warn(
169-
"MemoryUsageCallback: unsupported backend "
170-
f"'{backend_name}'",
171+
f"MemoryUsageCallback: no backend '{backend_name}'",
171172
RuntimeWarning,
172173
)
173174
self._warn_backend = True

0 commit comments

Comments
 (0)