Commit f340101

Fix openvino support
1 parent 105cbdc commit f340101

File tree

2 files changed: +187 -173 lines changed


keras/src/callbacks/memory_usage_callback.py

Lines changed: 107 additions & 79 deletions

@@ -1,100 +1,138 @@
 import os
 import warnings
-
 from keras.src.api_export import keras_export
 from keras.src.callbacks.callback import Callback
 from keras.src import backend as K
 
 try:
     import psutil
 except ImportError:
-    psutil = None
-
+    psutil = None
 
 @keras_export("keras.callbacks.MemoryUsageCallback")
 class MemoryUsageCallback(Callback):
-    """Monitor CPU/GPU/TPU/OpenVINO memory during training.
+    """Monitors and logs memory usage (CPU + optional GPU/TPU) during training.
+
+    This callback measures:
 
-    Tracks:
-    - CPU memory via `psutil.Process().memory_info().rss`.
-    - GPU memory via backend APIs (TF, Torch, JAX, OpenVINO).
-    - Logs to stdout and, optionally, to TensorBoard.
+    - **CPU**: via psutil.Process().memory_info().rss
+    - **GPU/TPU**: via backend-specific APIs (TensorFlow, PyTorch, JAX, OpenVINO)
+
+    Logs are printed to stdout at the start/end of each epoch and,
+    if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
+    is provided, scalars are also written via `tf.summary` (TensorBoard).
 
     Args:
-        monitor_gpu: Bool. If True, query GPU/accelerator memory.
-        log_every_batch: Bool. If True, log after each batch.
-        tensorboard_log_dir: str or None. If set, use TF summary writer.
+        monitor_gpu (bool): If True, attempt to measure accelerator memory.
+        log_every_batch (bool): If True, also log after each batch.
+        tensorboard_log_dir (str|None): Directory for TensorBoard logs;
+            if None, no TF summary writer is created.
 
     Raises:
-        ImportError: If `psutil` is missing.
+        ImportError: If `psutil` is not installed (required for CPU logging).
+
+    Example:
+
+    ```python
+    from keras.callbacks import MemoryUsageCallback
+    # ...
+    cb = MemoryUsageCallback(
+        monitor_gpu=True,
+        log_every_batch=False,
+        tensorboard_log_dir="./logs/memory"
+    )
+    model.fit(X, y, callbacks=[cb])
+    ```
     """
 
     def __init__(
         self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None
    ):
         super().__init__()
         if psutil is None:
-            raise ImportError("MemoryUsageCallback requires the 'psutil' library.")
+            raise ImportError(
+                "MemoryUsageCallback requires the 'psutil' library. "
+                "Install via `pip install psutil`."
+            )
         self.monitor_gpu = monitor_gpu
         self.log_every_batch = log_every_batch
-        self.process = psutil.Process()
-        self.tb_writer = None
-        self._batches_seen = 0
+        self._proc = psutil.Process()
+        self._step_counter = 0
+        self._writer = None
 
         if tensorboard_log_dir:
             try:
-                import tensorflow as tf
+                import tensorflow as tf
 
                 logdir = os.path.expanduser(tensorboard_log_dir)
-                self.tb_writer = tf.summary.create_file_writer(logdir)
+                self._writer = tf.summary.create_file_writer(logdir)
+                print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
             except Exception as e:
-                warnings.warn(f"TB init error: {e}", RuntimeWarning)
+                warnings.warn(
+                    f"Could not initialize TensorBoard writer: {e}", RuntimeWarning
+                )
+                self._writer = None
 
     def on_train_begin(self, logs=None):
-        self._batches_seen = 0
+        self._step_counter = 0
 
     def on_epoch_begin(self, epoch, logs=None):
-        cpu = self._cpu_mem_mb()
-        gpu = self._get_gpu_memory()
-        self._log("Epoch %d start" % epoch, epoch, cpu, gpu)
+        self._log_epoch("start", epoch)
 
     def on_epoch_end(self, epoch, logs=None):
-        cpu = self._cpu_mem_mb()
-        gpu = self._get_gpu_memory()
-        self._log("Epoch %d end" % epoch, epoch + 1, cpu, gpu)
+        self._log_epoch("end", epoch, offset=1)
 
     def on_batch_end(self, batch, logs=None):
         if self.log_every_batch:
-            cpu = self._cpu_mem_mb()
-            gpu = self._get_gpu_memory()
-            self._log(f"Batch {self._batches_seen} end", self._batches_seen, cpu, gpu)
-        self._batches_seen += 1
+            self._log_step(f"Batch {self._step_counter} end", self._step_counter)
+        self._step_counter += 1
 
     def on_train_end(self, logs=None):
-        if self.tb_writer:
-            self.tb_writer.close()
+        if self._writer:
+            self._writer.close()
+
+    def _log_epoch(self, when, epoch, offset=0):
+        label = f"Epoch {epoch} {when}"
+        step = epoch + offset
+        self._log_step(label, step)
+
+    def _log_step(self, label, step):
+        cpu_mb = self._get_cpu_memory()
+        gpu_mb = self._get_gpu_memory() if self.monitor_gpu else None
+
+        msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
+        if gpu_mb is not None:
+            msg += f"; GPU Memory: {gpu_mb:.2f} MB"
+        print(msg)
+
+        if self._writer:
+            import tensorflow as tf  # noqa: E501
 
-    def _cpu_mem_mb(self):
-        return self.process.memory_info().rss / (1024**2)
+            with self._writer.as_default(step=int(step)):
+                tf.summary.scalar("Memory/CPU_MB", cpu_mb)
+                if gpu_mb is not None:
+                    tf.summary.scalar("Memory/GPU_MB", gpu_mb)
+            self._writer.flush()
+
+    def _get_cpu_memory(self):
+        return self._proc.memory_info().rss / (1024**2)
 
     def _get_gpu_memory(self):
-        if not self.monitor_gpu:
-            return None
-        backend = K.backend()
+        backend_name = K.backend()
         try:
-            if backend == "tensorflow":
+            if backend_name == "tensorflow":
                 import tensorflow as tf
 
                 gpus = tf.config.list_physical_devices("GPU")
                 if not gpus:
                     return None
-                total = 0
-                for g in gpus:
-                    info = tf.config.experimental.get_memory_info(g.name)
-                    total += info.get("current", 0)
+                total = sum(
+                    tf.config.experimental.get_memory_info(g.name)["current"]
+                    for g in gpus
+                )
                 return total / (1024**2)
-            if backend == "torch":
-                import torch
+            elif backend_name == "torch":
+                import torch
 
                 if not torch.cuda.is_available():
                     return None
@@ -103,47 +141,37 @@ def _get_gpu_memory(self):
                     for i in range(torch.cuda.device_count())
                 )
                 return total / (1024**2)
-            if backend == "jax":
-                import jax
+            elif backend_name == "jax":
+                import jax
 
                 devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
                 if not devs:
                     return None
                 total = 0
                 for d in devs:
-                    stats = d.memory_stats()
+                    stats = getattr(d, "memory_stats", lambda: {})()
                     total += stats.get("bytes_in_use", 0)
                 return total / (1024**2)
-            if backend == "openvino":
-                try:
-                    import openvino as ov
-
-                    core = ov.Core()
-                    devices = core.available_devices
-                    total = 0
-                    for dev in devices:
-                        stats = core.get_property(dev, "DEVICE_MEMORY_STATISTICS")
-                        total += stats.get("deviceUsedBytes", 0)
-                    return total / (1024**2)
-                except Exception as e:
-                    warnings.warn(f"OVINO mem err: {e}", RuntimeWarning)
-                    return None
-        except ImportError as e:
-            warnings.warn(f"Import err for {backend}: {e}", RuntimeWarning)
+            else:
+                # OpenVINO and others fall back to unsupported
+
+                if not hasattr(self, "_warn_backend"):
+                    warnings.warn(
+                        f"MemoryUsageCallback: unsupported backend '{backend_name}'",
+                        RuntimeWarning,
+                    )
+                    self._warn_backend = True
+                return None
+        except ImportError as imp_err:
+            if not hasattr(self, "_warn_import"):
+                warnings.warn(
+                    f"Could not import for backend '{backend_name}': {imp_err}",
+                    RuntimeWarning,
+                )
+                self._warn_import = True
+            return None
+        except Exception as exc:
+            if not hasattr(self, "_warn_exc"):
+                warnings.warn(f"Error retrieving GPU memory: {exc}", RuntimeWarning)
+                self._warn_exc = True
             return None
-        warnings.warn(f"Unsupported backend '{backend}'", RuntimeWarning)
-        return None
-
-    def _log(self, label, step, cpu, gpu):
-        msg = f"{label} - CPU: {cpu:.2f} MB"
-        if gpu is not None:
-            msg += f"; GPU: {gpu:.2f} MB"
-        print(msg)
-        if self.tb_writer:
-            import tensorflow as tf
-
-            with self.tb_writer.as_default(step=step):
-                tf.summary.scalar("Memory/CPU_MB", cpu)
-                if gpu is not None:
-                    tf.summary.scalar("Memory/GPU_MB", gpu)
-            self.tb_writer.flush()
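
For reviewers, here is a minimal sketch of how the new fallback path behaves after this change: any backend without a dedicated accelerator-memory query (OpenVINO now goes through the generic `else:` branch) triggers a single RuntimeWarning and returns `None`, so only CPU memory is logged. This is an illustration under assumptions, not part of the commit: it presumes the patch is applied, `psutil` is installed, and the OpenVINO backend is available and selected through `KERAS_BACKEND` before Keras is imported.

```python
# Sketch only: exercises the unsupported-backend fallback added in this commit.
# Assumes the patch is applied, psutil is installed, and the OpenVINO backend
# can be selected via KERAS_BACKEND before Keras is imported.
import os
import warnings

os.environ["KERAS_BACKEND"] = "openvino"  # must be set before importing keras

from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback

cb = MemoryUsageCallback(monitor_gpu=True, log_every_batch=False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    first = cb._get_gpu_memory()   # unsupported backend -> warns once, returns None
    second = cb._get_gpu_memory()  # _warn_backend flag already set -> no repeat warning

print(first, second)                          # None None
print(len(caught))                            # 1 (RuntimeWarning about 'openvino')
print(f"CPU: {cb._get_cpu_memory():.2f} MB")  # CPU logging still works via psutil
```

The one-time `_warn_backend`, `_warn_import`, and `_warn_exc` flags are set on first use, so per-batch logging with `log_every_batch=True` does not repeat the same warning on every step.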
