Skip to content

Commit 728b770

Browse files
DimiChatzipavlis
DimiChatzipavlis
authored and
DimiChatzipavlis
committed
Format the code
1 parent e13528e commit 728b770

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

keras/src/callbacks/memory_usage_callback.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ class MemoryUsageCallback(Callback):
1919
This callback measures:
2020
2121
- **CPU**: via psutil.Process().memory_info().rss
22-
- **GPU/TPU**: via backend‑specific APIs
23-
(TensorFlow, PyTorch, JAX, OpenVINO)
22+
- **GPU/TPU**: via backend-specific APIs
23+
(TensorFlow, PyTorch, JAX, OpenVINO)
2424
2525
Logs are printed to stdout at the start/end of each epoch and,
26-
if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
26+
if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
2727
is provided, scalars are also written via `tf.summary` (TensorBoard).
2828
2929
Args:
@@ -36,7 +36,6 @@ class MemoryUsageCallback(Callback):
3636
ImportError: If `psutil` is not installed (required for CPU logging).
3737
3838
Example:
39-
4039
```python
4140
from keras.callbacks import MemoryUsageCallback
4241
# ...
@@ -46,7 +45,7 @@ class MemoryUsageCallback(Callback):
4645
tensorboard_log_dir="./logs/memory"
4746
)
4847
model.fit(X, y, callbacks=[cb])
49-
```
48+
```
5049
"""
5150

5251
def __init__(
@@ -67,7 +66,6 @@ def __init__(
6766
if tensorboard_log_dir:
6867
try:
6968
import tensorflow as tf
70-
7169
logdir = os.path.expanduser(tensorboard_log_dir)
7270
self._writer = tf.summary.create_file_writer(logdir)
7371
print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
@@ -89,9 +87,7 @@ def on_epoch_end(self, epoch, logs=None):
8987

9088
def on_batch_end(self, batch, logs=None):
9189
if self.log_every_batch:
92-
self._log_step(
93-
f"Batch {self._step_counter} end", self._step_counter
94-
)
90+
self._log_step(f"Batch {self._step_counter} end", self._step_counter)
9591
self._step_counter += 1
9692

9793
def on_train_end(self, logs=None):
@@ -110,65 +106,66 @@ def _log_step(self, label, step):
110106
msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
111107
if gpu_mb is not None:
112108
msg += f"; GPU Memory: {gpu_mb:.2f} MB"
113-
print(msg)
109+
# newline + flush ensures clean, immediate output
110+
print("\n" + msg, flush=True)
114111

115112
if self._writer:
116113
import tensorflow as tf # noqa: E501
117-
118114
with self._writer.as_default(step=int(step)):
119115
tf.summary.scalar("Memory/CPU_MB", cpu_mb)
120116
if gpu_mb is not None:
121117
tf.summary.scalar("Memory/GPU_MB", gpu_mb)
122118
self._writer.flush()
123119

124120
def _get_cpu_memory(self):
125-
return self._proc.memory_info().rss / (1024**2)
121+
return self._proc.memory_info().rss / (1024 ** 2)
126122

127123
def _get_gpu_memory(self):
128124
backend_name = K.backend()
129125
try:
130126
if backend_name == "tensorflow":
131127
import tensorflow as tf
132-
133128
gpus = tf.config.list_physical_devices("GPU")
134129
if not gpus:
135130
return None
136131
total = sum(
137132
tf.config.experimental.get_memory_info(g.name)["current"]
138133
for g in gpus
139134
)
140-
return total / (1024**2)
135+
return total / (1024 ** 2)
136+
141137
elif backend_name == "torch":
142138
import torch
143-
144139
if not torch.cuda.is_available():
145140
return None
146141
total = sum(
147142
torch.cuda.memory_allocated(i)
148143
for i in range(torch.cuda.device_count())
149144
)
150-
return total / (1024**2)
145+
return total / (1024 ** 2)
146+
151147
elif backend_name == "jax":
152148
import jax
153-
154149
devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
155150
if not devs:
156151
return None
157152
total = 0
158153
for d in devs:
159154
stats = getattr(d, "memory_stats", lambda: {})()
160155
total += stats.get("bytes_in_use", 0)
161-
return total / (1024**2)
162-
else:
163-
# OpenVINO and others fall back to unsupported
156+
return total / (1024 ** 2)
164157

158+
else:
159+
# OpenVINO and other unknown backends: warn once
165160
if not hasattr(self, "_warn_backend"):
166161
warnings.warn(
167-
f"MemoryUsageCallback: unsupported backend '{backend_name}'",
162+
"MemoryUsageCallback: unsupported backend "
163+
f"'{backend_name}'",
168164
RuntimeWarning,
169165
)
170166
self._warn_backend = True
171167
return None
168+
172169
except ImportError as imp_err:
173170
if not hasattr(self, "_warn_import"):
174171
warnings.warn(
@@ -177,6 +174,7 @@ def _get_gpu_memory(self):
177174
)
178175
self._warn_import = True
179176
return None
177+
180178
except Exception as exc:
181179
if not hasattr(self, "_warn_exc"):
182180
warnings.warn(

0 commit comments

Comments (0)