import os
import warnings
-
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback
from keras.src import backend as K

try:
    import psutil
except ImportError:
-    psutil = None
-
+    psutil = None

@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
-    """Monitor CPU/GPU/TPU/OpenVINO memory during training.
+    """Monitors and logs memory usage (CPU + optional GPU/TPU) during training.
+
+    This callback measures:

-    Tracks:
-    - CPU memory via `psutil.Process().memory_info().rss`.
-    - GPU memory via backend APIs (TF, Torch, JAX, OpenVINO).
-    - Logs to stdout and, optionally, to TensorBoard.
+    - **CPU**: via psutil.Process().memory_info().rss
+    - **GPU/TPU**: via backend-specific APIs (TensorFlow, PyTorch, JAX, OpenVINO)
+
+    Logs are printed to stdout at the start/end of each epoch and,
+    if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
+    is provided, scalars are also written via `tf.summary` (TensorBoard).

    Args:
-        monitor_gpu: Bool. If True, query GPU/accelerator memory.
-        log_every_batch: Bool. If True, log after each batch.
-        tensorboard_log_dir: str or None. If set, use TF summary writer.
+        monitor_gpu (bool): If True, attempt to measure accelerator memory.
+        log_every_batch (bool): If True, also log after each batch.
+        tensorboard_log_dir (str|None): Directory for TensorBoard logs;
+            if None, no TF summary writer is created.

    Raises:
-        ImportError: If `psutil` is missing.
+        ImportError: If `psutil` is not installed (required for CPU logging).
+
+    Example:
+
+    ```python
+    from keras.callbacks import MemoryUsageCallback
+    # ...
+    cb = MemoryUsageCallback(
+        monitor_gpu=True,
+        log_every_batch=False,
+        tensorboard_log_dir="./logs/memory"
+    )
+    model.fit(X, y, callbacks=[cb])
+    ```
    """

    def __init__(
        self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None
    ):
        super().__init__()
        if psutil is None:
-            raise ImportError("MemoryUsageCallback requires the 'psutil' library.")
+            raise ImportError(
+                "MemoryUsageCallback requires the 'psutil' library. "
+                "Install via `pip install psutil`."
+            )
        self.monitor_gpu = monitor_gpu
        self.log_every_batch = log_every_batch
-        self.process = psutil.Process()
-        self.tb_writer = None
-        self._batches_seen = 0
+        self._proc = psutil.Process()
+        self._step_counter = 0
+        self._writer = None

        if tensorboard_log_dir:
            try:
-                import tensorflow as tf
+                import tensorflow as tf

                logdir = os.path.expanduser(tensorboard_log_dir)
-                self.tb_writer = tf.summary.create_file_writer(logdir)
+                self._writer = tf.summary.create_file_writer(logdir)
+                print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
            except Exception as e:
-                warnings.warn(f"TB init error: {e}", RuntimeWarning)
+                warnings.warn(
+                    f"Could not initialize TensorBoard writer: {e}", RuntimeWarning
+                )
+                self._writer = None

    def on_train_begin(self, logs=None):
-        self._batches_seen = 0
+        self._step_counter = 0

    def on_epoch_begin(self, epoch, logs=None):
-        cpu = self._cpu_mem_mb()
-        gpu = self._get_gpu_memory()
-        self._log("Epoch %d start" % epoch, epoch, cpu, gpu)
+        self._log_epoch("start", epoch)

    def on_epoch_end(self, epoch, logs=None):
-        cpu = self._cpu_mem_mb()
-        gpu = self._get_gpu_memory()
-        self._log("Epoch %d end" % epoch, epoch + 1, cpu, gpu)
+        self._log_epoch("end", epoch, offset=1)

    def on_batch_end(self, batch, logs=None):
        if self.log_every_batch:
-            cpu = self._cpu_mem_mb()
-            gpu = self._get_gpu_memory()
-            self._log(f"Batch {self._batches_seen} end", self._batches_seen, cpu, gpu)
-        self._batches_seen += 1
+            self._log_step(f"Batch {self._step_counter} end", self._step_counter)
+        self._step_counter += 1

    def on_train_end(self, logs=None):
-        if self.tb_writer:
-            self.tb_writer.close()
+        if self._writer:
+            self._writer.close()
+
+    def _log_epoch(self, when, epoch, offset=0):
+        label = f"Epoch {epoch} {when}"
+        step = epoch + offset
+        self._log_step(label, step)
+
+    def _log_step(self, label, step):
+        cpu_mb = self._get_cpu_memory()
+        gpu_mb = self._get_gpu_memory() if self.monitor_gpu else None
+
+        msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
+        if gpu_mb is not None:
+            msg += f"; GPU Memory: {gpu_mb:.2f} MB"
+        print(msg)
+
+        if self._writer:
+            import tensorflow as tf  # noqa: E501

-    def _cpu_mem_mb(self):
-        return self.process.memory_info().rss / (1024 ** 2)
+            with self._writer.as_default(step=int(step)):
+                tf.summary.scalar("Memory/CPU_MB", cpu_mb)
+                if gpu_mb is not None:
+                    tf.summary.scalar("Memory/GPU_MB", gpu_mb)
+            self._writer.flush()
+
+    def _get_cpu_memory(self):
+        return self._proc.memory_info().rss / (1024 ** 2)

    def _get_gpu_memory(self):
-        if not self.monitor_gpu:
-            return None
-        backend = K.backend()
+        backend_name = K.backend()
        try:
-            if backend == "tensorflow":
+            if backend_name == "tensorflow":
                import tensorflow as tf

                gpus = tf.config.list_physical_devices("GPU")
                if not gpus:
                    return None
-                total = 0
-                for g in gpus:
-                    info = tf.config.experimental.get_memory_info(g.name)
-                    total += info.get("current", 0)
+                total = sum(
+                    tf.config.experimental.get_memory_info(g.name)["current"]
+                    for g in gpus
+                )
                return total / (1024 ** 2)
-            if backend == "torch":
-                import torch
+            elif backend_name == "torch":
+                import torch

                if not torch.cuda.is_available():
                    return None
@@ -103,47 +141,37 @@ def _get_gpu_memory(self):
                    for i in range(torch.cuda.device_count())
                )
                return total / (1024 ** 2)
-            if backend == "jax":
-                import jax
+            elif backend_name == "jax":
+                import jax

                devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
                if not devs:
                    return None
                total = 0
                for d in devs:
-                    stats = d.memory_stats()
+                    stats = getattr(d, "memory_stats", lambda: {})()
                    total += stats.get("bytes_in_use", 0)
                return total / (1024 ** 2)
-            if backend == "openvino":
-                try:
-                    import openvino as ov
-
-                    core = ov.Core()
-                    devices = core.available_devices
-                    total = 0
-                    for dev in devices:
-                        stats = core.get_property(dev, "DEVICE_MEMORY_STATISTICS")
-                        total += stats.get("deviceUsedBytes", 0)
-                    return total / (1024 ** 2)
-                except Exception as e:
-                    warnings.warn(f"OVINO mem err: {e}", RuntimeWarning)
-                    return None
-        except ImportError as e:
-            warnings.warn(f"Import err for {backend}: {e}", RuntimeWarning)
+            else:
+                # OpenVINO and others fall back to unsupported
+
+                if not hasattr(self, "_warn_backend"):
+                    warnings.warn(
+                        f"MemoryUsageCallback: unsupported backend '{backend_name}'",
+                        RuntimeWarning,
+                    )
+                    self._warn_backend = True
+                return None
+        except ImportError as imp_err:
+            if not hasattr(self, "_warn_import"):
+                warnings.warn(
+                    f"Could not import for backend '{backend_name}': {imp_err}",
+                    RuntimeWarning,
+                )
+                self._warn_import = True
+            return None
+        except Exception as exc:
+            if not hasattr(self, "_warn_exc"):
+                warnings.warn(f"Error retrieving GPU memory: {exc}", RuntimeWarning)
+                self._warn_exc = True
            return None
-        warnings.warn(f"Unsupported backend '{backend}'", RuntimeWarning)
-        return None
-
-    def _log(self, label, step, cpu, gpu):
-        msg = f"{label} - CPU: {cpu:.2f} MB"
-        if gpu is not None:
-            msg += f"; GPU: {gpu:.2f} MB"
-        print(msg)
-        if self.tb_writer:
-            import tensorflow as tf
-
-            with self.tb_writer.as_default(step=step):
-                tf.summary.scalar("Memory/CPU_MB", cpu)
-                if gpu is not None:
-                    tf.summary.scalar("Memory/GPU_MB", gpu)
-            self.tb_writer.flush()
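For reviewers who want to exercise the callback end to end, here is a minimal sketch. It is not part of the diff above and assumes this branch is installed and the class is importable as `keras.callbacks.MemoryUsageCallback` (per the `@keras_export` decorator); the synthetic data and toy model are illustrative only.

```python
import numpy as np
import keras
from keras.callbacks import MemoryUsageCallback  # available once this PR is installed

# Synthetic binary-classification data, just enough to drive a couple of epochs.
x = np.random.rand(256, 20).astype("float32")
y = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy")

memory_cb = MemoryUsageCallback(
    monitor_gpu=True,                     # set False to skip accelerator queries
    log_every_batch=False,                # per-epoch logging only
    tensorboard_log_dir="./logs/memory",  # needs TensorFlow for the tf.summary writer
)
model.fit(x, y, epochs=2, batch_size=32, callbacks=[memory_cb])
```

With `tensorboard_log_dir` set, the `Memory/CPU_MB` and `Memory/GPU_MB` scalars written by `_log_step` can then be inspected with `tensorboard --logdir ./logs/memory`.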