
Commit 5f9d975

Add memory usage monitor callback
1 parent 7695601 commit 5f9d975

2 files changed: 200 additions, 313 deletions

Lines changed: 115 additions & 187 deletions
@@ -1,238 +1,166 @@
 import os
 import warnings
-import tensorflow as tf # Ensure TF is imported for tf.summary
-from keras.src import backend as K
+
 from keras.src.api_export import keras_export
 from keras.src.callbacks.callback import Callback
+from keras.src import backend as K

-# Attempt to import psutil and warn if unavailable.
+# Attempt to import psutil for CPU memory
 try:
     import psutil
 except ImportError:
     psutil = None

+
 @keras_export("keras.callbacks.MemoryUsageCallback")
 class MemoryUsageCallback(Callback):
-    """Callback for enhanced monitoring of memory usage during training.
-
-    This callback tracks CPU memory usage via `psutil.Process().memory_info().rss`
-    and optionally GPU memory usage via backend-specific APIs (TensorFlow, PyTorch, JAX).
-    Memory statistics are logged to stdout at the start and end of each epoch and,
-    optionally, after every batch. Additionally, metrics are logged to TensorBoard
-    if a log directory is provided, using integer steps for proper visualization.
-
-    Note: GPU memory reporting consistency across backends (TF, PyTorch, JAX)
-    may vary, as they use different underlying mechanisms to measure usage
-    (e.g., framework overhead vs. purely tensor allocations).
-
-    Args:
-        monitor_gpu (bool): Whether to monitor GPU memory. Defaults to True.
-            Requires appropriate backend (TensorFlow, PyTorch, JAX) with GPU
-            support and necessary drivers/libraries installed.
-        log_every_batch (bool): Whether to log memory usage after each batch
-            in addition to epoch start/end. Defaults to False.
-        tensorboard_log_dir (str, optional): Path to the directory where TensorBoard
-            logs will be written using `tf.summary`. If None, TensorBoard logging
-            is disabled. Defaults to None. Requires TensorFlow to be installed.
-
-    Raises:
-        ImportError: If `psutil` is not installed.
+    """
+    Monitors CPU and GPU memory across backends and logs to stdout and TensorBoard.

     Example:
     ```python
-    import tensorflow as tf
-    import keras
-    from keras.callbacks import MemoryUsageCallback # Use public API path
-    import numpy as np
-
-    # Ensure psutil is installed: pip install psutil
-
-    memory_callback = MemoryUsageCallback(
-        monitor_gpu=True, # Set based on GPU availability and backend support
+    from keras.callbacks import MemoryUsageCallback
+    callback = MemoryUsageCallback(
+        monitor_gpu=True,
         log_every_batch=False,
-        tensorboard_log_dir="~/logs/memory_usage" # Needs TF installed
+        tensorboard_log_dir="./logs"
     )
+    model.fit(..., callbacks=[callback])
+    ```

-    model = keras.models.Sequential([
-        keras.layers.Dense(64, activation='relu', input_shape=(100,)),
-        keras.layers.Dense(10, activation='softmax')
-    ])
-    model.compile(optimizer='adam', loss='categorical_crossentropy')
+    Args:
+        monitor_gpu (bool): Whether to log GPU memory. Defaults to True.
+        log_every_batch (bool): Whether to log after every batch. Defaults to False.
+        tensorboard_log_dir (str): Directory for TensorBoard logs; None disables. Defaults to None.

-    x_train = np.random.random((100, 100))
-    y_train = keras.utils.to_categorical(
-        np.random.randint(10, size=(100, 1)), num_classes=10
-    )
-    model.fit(x_train, y_train, epochs=5, batch_size=32, callbacks=[memory_callback])
-    ```
+    Raises:
+        ImportError: If psutil is not installed.
     """
-    def __init__(self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None):
+
+    def __init__(
+        self,
+        monitor_gpu=True,
+        log_every_batch=False,
+        tensorboard_log_dir=None,
+    ):
         super().__init__()
         if psutil is None:
             raise ImportError(
-                "MemoryUsageCallback requires the 'psutil' library. "
-                "Please install it using 'pip install psutil'."
+                "MemoryUsageCallback requires `psutil`; install via `pip install psutil`."
             )
         self.monitor_gpu = monitor_gpu
         self.log_every_batch = log_every_batch
         self.process = psutil.Process()
         self.tb_writer = None
-        self._total_batches_seen = 0 # For TensorBoard step counting
+        self._batch_count = 0

         if tensorboard_log_dir:
-            # tf.summary requires TensorFlow installed.
-            if tf is None:
-                warnings.warn(
-                    "MemoryUsageCallback: TensorFlow is required for TensorBoard logging. "
-                    "Please install TensorFlow.", ImportWarning
-                )
-                self.tb_writer = None
-            else:
-                try:
-                    log_dir = os.path.expanduser(tensorboard_log_dir)
-                    # Use tf.summary for robust integration
-                    self.tb_writer = tf.summary.create_file_writer(log_dir)
-                    print(f"MemoryUsageCallback: TensorBoard logging initialized at {log_dir}")
-                except Exception as e:
-                    warnings.warn(f"Error initializing TensorBoard writer: {e}", RuntimeWarning)
-                    self.tb_writer = None
-
-    def on_train_begin(self, logs=None):
-        """Reset batch counter at the start of training."""
-        self._total_batches_seen = 0
-
-    def on_epoch_begin(self, epoch, logs=None):
-        """Log memory usage at the beginning of each epoch."""
-        cpu_mem = self._get_cpu_memory()
-        gpu_mem = self._get_gpu_memory()
-        self._log_memory(
-            label=f"Epoch {epoch} start",
-            step=epoch, # Use epoch number for TB step
-            cpu_mem=cpu_mem,
-            gpu_mem=gpu_mem
-        )
-
-    def on_epoch_end(self, epoch, logs=None):
-        """Log memory usage at the end of each epoch."""
-        cpu_mem = self._get_cpu_memory()
-        gpu_mem = self._get_gpu_memory()
-        # Use epoch + 1 for TB step to mark the end point distinctly
-        self._log_memory(
-            label=f"Epoch {epoch} end",
-            step=epoch + 1,
-            cpu_mem=cpu_mem,
-            gpu_mem=gpu_mem
-        )
-
-    def on_batch_end(self, batch, logs=None):
-        """If enabled, log memory usage at the end of each batch."""
-        if self.log_every_batch:
-            cpu_mem = self._get_cpu_memory()
-            gpu_mem = self._get_gpu_memory()
-            # Use the total batches seen count for a continuous TB step
-            self._log_memory(
-                label=f"Batch {self._total_batches_seen} end",
-                step=self._total_batches_seen,
-                cpu_mem=cpu_mem,
-                gpu_mem=gpu_mem
-            )
-        # Always increment, even if not logging
-        self._total_batches_seen += 1
+            try:
+                import tensorflow as tf

-    def on_train_end(self, logs=None):
-        """Clean up the TensorBoard writer."""
-        if self.tb_writer:
-            self.tb_writer.close()
-            self.tb_writer = None
+                logdir = os.path.expanduser(tensorboard_log_dir)
+                self.tb_writer = tf.summary.create_file_writer(logdir)
+            except ImportError as e:
+                warnings.warn(f"TensorBoard disabled (no TF): {e}", RuntimeWarning)
+            except Exception as e:
+                warnings.warn(
+                    f"Failed to init TB writer at {tensorboard_log_dir}: {e}",
+                    RuntimeWarning,
+                )

     def _get_cpu_memory(self):
-        """Return current process CPU memory usage in MB."""
-        return self.process.memory_info().rss / (1024 ** 2)
+        """Return resident set size in MB."""
+        return self.process.memory_info().rss / (1024**2)

     def _get_gpu_memory(self):
-        """Return current GPU memory usage in MB based on backend."""
+        """Return GPU memory usage in MB or None."""
         if not self.monitor_gpu:
             return None
-
         backend = K.backend()
-        gpu_mem_mb = None
         try:
             if backend == "tensorflow":
+                import tensorflow as tf
+
                 gpus = tf.config.list_physical_devices("GPU")
-                if not gpus: return None
-                total_mem_bytes = 0
+                if not gpus:
+                    return None
+                total = 0
                 for gpu in gpus:
-                    mem_info = tf.config.experimental.get_memory_info(gpu.name)
-                    total_mem_bytes += mem_info.get("current", 0)
-                gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-            elif backend == "torch":
-                # Note: memory_allocated() tracks only tensors, might differ from TF.
-                import torch
-                if not torch.cuda.is_available(): return None
-                # Sum memory allocated across all visible GPUs
-                total_mem_bytes = sum(torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count()))
-                gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-            elif backend == "jax":
-                # Note: JAX memory stats might also differ from TF/Torch in scope.
-                import jax
-                devices = jax.devices()
-                gpu_devices = [d for d in devices if d.platform.upper() == 'GPU'] # Filter for GPU devices
-                if not gpu_devices: return None
-                total_mem_bytes = 0
-                for device in gpu_devices:
-                    try:
-                        # memory_stats() might not be available or API could change
-                        stats = device.memory_stats()
-                        total_mem_bytes += stats.get("bytes_in_use", stats.get("allocated_bytes", 0)) # Try common keys
-                    except Exception:
-                        # Ignore if stats are unavailable for a device
-                        pass
-                gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-            else:
-                if not hasattr(self, '_backend_warned'):
-                    warnings.warn(f"Unsupported backend '{backend}' for GPU memory monitoring.", RuntimeWarning)
-                    self._backend_warned = True
-                return None
+                    info = tf.config.experimental.get_memory_info(gpu.name)
+                    total += info.get("current", 0)
+                return total / (1024**2)
+
+            if backend == "torch":
+                import torch
+
+                if not torch.cuda.is_available():
+                    return None
+                total = sum(
+                    torch.cuda.memory_allocated(i)
+                    for i in range(torch.cuda.device_count())
+                )
+                return total / (1024**2)
+
+            if backend == "jax":
+                import jax
+
+                devs = [d for d in jax.devices() if d.platform == "gpu"]
+                if not devs:
+                    return None
+                total = 0
+                for d in devs:
+                    stats = getattr(d, "memory_stats", lambda: {})()
+                    total += stats.get("bytes_in_use", stats.get("allocated_bytes", 0))
+                return total / (1024**2)
+
+            if not hasattr(self, "_warned_backend"):
+                warnings.warn(
+                    f"Backend '{backend}' not supported for GPU memory.",
+                    RuntimeWarning,
+                )
+                self._warned_backend = True
+            return None

         except ImportError as e:
-            # Backend library might not be installed
-            if not hasattr(self, f'_{backend}_import_warned'):
-                warnings.warn(f"MemoryUsageCallback: Could not import library for backend '{backend}': {e}. "
-                              f"GPU monitoring disabled for this backend.", RuntimeWarning)
-                setattr(self, f'_{backend}_import_warned', True)
-            return None
+            warnings.warn(
+                f"Could not import backend lib ({e}); GPU disabled.",
+                RuntimeWarning,
+            )
+            return None
         except Exception as e:
-            # Catch other potential errors during memory retrieval
-            if not hasattr(self, f'_{backend}_error_warned'):
-                warnings.warn(f"MemoryUsageCallback: Error retrieving GPU memory info for backend '{backend}': {e}", RuntimeWarning)
-                setattr(self, f'_{backend}_error_warned', True)
+            warnings.warn(f"Error retrieving GPU memory ({e}).", RuntimeWarning)
             return None

-        return gpu_mem_mb
+    def _log(self, label, step):
+        cpu = self._get_cpu_memory()
+        gpu = self._get_gpu_memory()
+        msg = f"{label} - CPU Memory: {cpu:.2f} MB"
+        if gpu is not None:
+            msg += f"; GPU Memory: {gpu:.2f} MB"
+        print(msg)
+        if self.tb_writer:
+            import tensorflow as tf
+
+            with self.tb_writer.as_default(step=int(step)):
+                tf.summary.scalar("Memory/CPU_MB", cpu)
+                if gpu is not None:
+                    tf.summary.scalar("Memory/GPU_MB", gpu)
+            self.tb_writer.flush()

+    def on_train_begin(self, logs=None):
+        self._batch_count = 0

-    def _log_memory(self, label, step, cpu_mem, gpu_mem):
-        """Log memory metrics to stdout and potentially TensorBoard."""
-        message = f"{label} - CPU Memory: {cpu_mem:.2f} MB"
-        if gpu_mem is not None:
-            message += f"; GPU Memory: {gpu_mem:.2f} MB"
-        print(message) # Log to stdout
+    def on_epoch_begin(self, epoch, logs=None):
+        self._log(f"Epoch {epoch} start", epoch)

-        # Log to TensorBoard if writer is configured
+    def on_epoch_end(self, epoch, logs=None):
+        self._log(f"Epoch {epoch} end", epoch + 1)
+
+    def on_batch_end(self, batch, logs=None):
+        if self.log_every_batch:
+            self._log(f"Batch {self._batch_count} end", self._batch_count)
+        self._batch_count += 1
+
+    def on_train_end(self, logs=None):
         if self.tb_writer:
-            try:
-                with self.tb_writer.as_default(step=int(step)):
-                    tf.summary.scalar("Memory/CPU_MB", cpu_mem)
-                    if gpu_mem is not None:
-                        tf.summary.scalar("Memory/GPU_MB", gpu_mem)
-                self.tb_writer.flush()
-            except Exception as e:
-                # Catch potential errors during logging (e.g., writer closed unexpectedly)
-                if not hasattr(self, '_tb_log_error_warned'):
-                    warnings.warn(f"MemoryUsageCallback: Error writing to TensorBoard: {e}", RuntimeWarning)
-                    self._tb_log_error_warned = True
-                # Optionally disable writer if logging fails persistently
-                # self.tb_writer = None
+            self.tb_writer.close()
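
For context, a minimal usage sketch of the callback this commit adds. The toy model, random data, and log directory below are illustrative assumptions, not part of the commit; it assumes `psutil` is installed and that the class is importable as `keras.callbacks.MemoryUsageCallback`, as the new docstring example shows.

```python
# Illustrative usage sketch (not part of the commit): a toy model trained on
# random data, with the callback logging memory at each epoch boundary.
import numpy as np
import keras
from keras.callbacks import MemoryUsageCallback  # requires `pip install psutil`

model = keras.Sequential(
    [
        keras.layers.Input(shape=(100,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="categorical_crossentropy")

x = np.random.random((256, 100)).astype("float32")
y = keras.utils.to_categorical(np.random.randint(10, size=(256,)), num_classes=10)

callback = MemoryUsageCallback(
    monitor_gpu=True,  # GPU values appear only if the backend exposes a GPU
    log_every_batch=False,  # log at epoch start/end only
    tensorboard_log_dir="./logs/memory",  # hypothetical path; None disables TB output
)
model.fit(x, y, epochs=2, batch_size=32, callbacks=[callback])
```

When `tensorboard_log_dir` is set, the values are written as the `Memory/CPU_MB` and `Memory/GPU_MB` scalars, so `tensorboard --logdir ./logs/memory` plots them against the epoch or batch step.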

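The probes the callback wraps can also be run standalone to sanity-check readings. Below is a rough sketch of the same kind of probes, not the callback's exact code: it assumes the TensorFlow backend and uses the documented `"GPU:<i>"` device-name form for `get_memory_info`; the PyTorch and JAX branches in the diff use `torch.cuda.memory_allocated` and the devices' `memory_stats()` instead, so readings are not directly comparable across backends.

```python
# Standalone sketch of the memory probes the callback relies on.
# Assumes TensorFlow and psutil are installed; illustrative only.
import psutil
import tensorflow as tf

# CPU: resident set size of the current process, in MB.
rss_mb = psutil.Process().memory_info().rss / (1024**2)
print(f"CPU RSS: {rss_mb:.2f} MB")

# GPU: bytes currently allocated by the TF runtime, per logical device.
for i, _ in enumerate(tf.config.list_physical_devices("GPU")):
    info = tf.config.experimental.get_memory_info(f"GPU:{i}")
    print(f"GPU:{i}: {info['current'] / (1024**2):.2f} MB in use")
```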