import os
import warnings
- import tensorflow as tf  # Ensure TF is imported for tf.summary
- from keras.src import backend as K
+
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback
+ from keras.src import backend as K

- # Attempt to import psutil and warn if unavailable.
+ # Attempt to import psutil for CPU memory
try:
    import psutil
except ImportError:
    psutil = None

+
@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
-     """Callback for enhanced monitoring of memory usage during training.
-
-     This callback tracks CPU memory usage via `psutil.Process().memory_info().rss`
-     and optionally GPU memory usage via backend-specific APIs (TensorFlow, PyTorch, JAX).
-     Memory statistics are logged to stdout at the start and end of each epoch and,
-     optionally, after every batch. Additionally, metrics are logged to TensorBoard
-     if a log directory is provided, using integer steps for proper visualization.
-
-     Note: GPU memory reporting consistency across backends (TF, PyTorch, JAX)
-     may vary, as they use different underlying mechanisms to measure usage
-     (e.g., framework overhead vs. purely tensor allocations).
-
-     Args:
-         monitor_gpu (bool): Whether to monitor GPU memory. Defaults to True.
-             Requires appropriate backend (TensorFlow, PyTorch, JAX) with GPU
-             support and necessary drivers/libraries installed.
-         log_every_batch (bool): Whether to log memory usage after each batch
-             in addition to epoch start/end. Defaults to False.
-         tensorboard_log_dir (str, optional): Path to the directory where TensorBoard
-             logs will be written using `tf.summary`. If None, TensorBoard logging
-             is disabled. Defaults to None. Requires TensorFlow to be installed.
-
-     Raises:
-         ImportError: If `psutil` is not installed.
+     """
+     Monitors CPU and GPU memory across backends and logs to stdout and TensorBoard.

    Example:
    ```python
-     import tensorflow as tf
-     import keras
-     from keras.callbacks import MemoryUsageCallback  # Use public API path
-     import numpy as np
-
-     # Ensure psutil is installed: pip install psutil
-
-     memory_callback = MemoryUsageCallback(
-         monitor_gpu=True,  # Set based on GPU availability and backend support
+     from keras.callbacks import MemoryUsageCallback
+     callback = MemoryUsageCallback(
+         monitor_gpu=True,
        log_every_batch=False,
-         tensorboard_log_dir="~/logs/memory_usage"  # Needs TF installed
+         tensorboard_log_dir="./logs"
    )
+     model.fit(..., callbacks=[callback])
+     ```

-     model = keras.models.Sequential([
-         keras.layers.Dense(64, activation='relu', input_shape=(100,)),
-         keras.layers.Dense(10, activation='softmax')
-     ])
-     model.compile(optimizer='adam', loss='categorical_crossentropy')
+     Args:
+         monitor_gpu (bool): Whether to log GPU memory. Defaults to True.
+         log_every_batch (bool): Whether to log after every batch. Defaults to False.
+         tensorboard_log_dir (str): Directory for TensorBoard logs; None disables. Defaults to None.

-     x_train = np.random.random((100, 100))
-     y_train = keras.utils.to_categorical(
-         np.random.randint(10, size=(100, 1)), num_classes=10
-     )
-     model.fit(x_train, y_train, epochs=5, batch_size=32, callbacks=[memory_callback])
-     ```
+     Raises:
+         ImportError: If psutil is not installed.
    """
-     def __init__(self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None):
+
+     def __init__(
+         self,
+         monitor_gpu=True,
+         log_every_batch=False,
+         tensorboard_log_dir=None,
+     ):
        super().__init__()
        if psutil is None:
            raise ImportError(
-                 "MemoryUsageCallback requires the 'psutil' library. "
-                 "Please install it using 'pip install psutil'."
+                 "MemoryUsageCallback requires `psutil`; install via `pip install psutil`."
            )
        self.monitor_gpu = monitor_gpu
        self.log_every_batch = log_every_batch
        self.process = psutil.Process()
        self.tb_writer = None
-         self._total_batches_seen = 0  # For TensorBoard step counting
+         self._batch_count = 0

        if tensorboard_log_dir:
-             # tf.summary requires TensorFlow installed.
-             if tf is None:
-                 warnings.warn(
-                     "MemoryUsageCallback: TensorFlow is required for TensorBoard logging. "
-                     "Please install TensorFlow.", ImportWarning
-                 )
-                 self.tb_writer = None
-             else:
-                 try:
-                     log_dir = os.path.expanduser(tensorboard_log_dir)
-                     # Use tf.summary for robust integration
-                     self.tb_writer = tf.summary.create_file_writer(log_dir)
-                     print(f"MemoryUsageCallback: TensorBoard logging initialized at {log_dir}")
-                 except Exception as e:
-                     warnings.warn(f"Error initializing TensorBoard writer: {e}", RuntimeWarning)
-                     self.tb_writer = None
-
-     def on_train_begin(self, logs=None):
-         """Reset batch counter at the start of training."""
-         self._total_batches_seen = 0
-
-     def on_epoch_begin(self, epoch, logs=None):
-         """Log memory usage at the beginning of each epoch."""
-         cpu_mem = self._get_cpu_memory()
-         gpu_mem = self._get_gpu_memory()
-         self._log_memory(
-             label=f"Epoch {epoch} start",
-             step=epoch,  # Use epoch number for TB step
-             cpu_mem=cpu_mem,
-             gpu_mem=gpu_mem
-         )
-
-     def on_epoch_end(self, epoch, logs=None):
-         """Log memory usage at the end of each epoch."""
-         cpu_mem = self._get_cpu_memory()
-         gpu_mem = self._get_gpu_memory()
-         # Use epoch + 1 for TB step to mark the end point distinctly
-         self._log_memory(
-             label=f"Epoch {epoch} end",
-             step=epoch + 1,
-             cpu_mem=cpu_mem,
-             gpu_mem=gpu_mem
-         )
-
-     def on_batch_end(self, batch, logs=None):
-         """If enabled, log memory usage at the end of each batch."""
-         if self.log_every_batch:
-             cpu_mem = self._get_cpu_memory()
-             gpu_mem = self._get_gpu_memory()
-             # Use the total batches seen count for a continuous TB step
-             self._log_memory(
-                 label=f"Batch {self._total_batches_seen} end",
-                 step=self._total_batches_seen,
-                 cpu_mem=cpu_mem,
-                 gpu_mem=gpu_mem
-             )
-         # Always increment, even if not logging
-         self._total_batches_seen += 1
+             try:
+                 import tensorflow as tf

-     def on_train_end(self, logs=None):
-         """Clean up the TensorBoard writer."""
-         if self.tb_writer:
-             self.tb_writer.close()
-             self.tb_writer = None
+                 logdir = os.path.expanduser(tensorboard_log_dir)
+                 self.tb_writer = tf.summary.create_file_writer(logdir)
+             except ImportError as e:
+                 warnings.warn(f"TensorBoard disabled (no TF): {e}", RuntimeWarning)
+             except Exception as e:
+                 warnings.warn(
+                     f"Failed to init TB writer at {tensorboard_log_dir}: {e}",
+                     RuntimeWarning,
+                 )

    def _get_cpu_memory(self):
-         """Return current process CPU memory usage in MB."""
-         return self.process.memory_info().rss / (1024 ** 2)
+         """Return resident set size in MB."""
+         return self.process.memory_info().rss / (1024 ** 2)

    def _get_gpu_memory(self):
-         """Return current GPU memory usage in MB based on backend."""
+         """Return GPU memory usage in MB or None."""
        if not self.monitor_gpu:
            return None
-
        backend = K.backend()
-         gpu_mem_mb = None
        try:
            if backend == "tensorflow":
+                 import tensorflow as tf
+
                gpus = tf.config.list_physical_devices("GPU")
-                 if not gpus: return None
-                 total_mem_bytes = 0
+                 if not gpus:
+                     return None
+                 total = 0
                for gpu in gpus:
-                     mem_info = tf.config.experimental.get_memory_info(gpu.name)
-                     total_mem_bytes += mem_info.get("current", 0)
-                 gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-             elif backend == "torch":
-                 # Note: memory_allocated() tracks only tensors, might differ from TF.
-                 import torch
-                 if not torch.cuda.is_available(): return None
-                 # Sum memory allocated across all visible GPUs
-                 total_mem_bytes = sum(torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count()))
-                 gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-             elif backend == "jax":
-                 # Note: JAX memory stats might also differ from TF/Torch in scope.
-                 import jax
-                 devices = jax.devices()
-                 gpu_devices = [d for d in devices if d.platform.upper() == 'GPU']  # Filter for GPU devices
-                 if not gpu_devices: return None
-                 total_mem_bytes = 0
-                 for device in gpu_devices:
-                     try:
-                         # memory_stats() might not be available or API could change
-                         stats = device.memory_stats()
-                         total_mem_bytes += stats.get("bytes_in_use", stats.get("allocated_bytes", 0))  # Try common keys
-                     except Exception:
-                         # Ignore if stats are unavailable for a device
-                         pass
-                 gpu_mem_mb = total_mem_bytes / (1024 ** 2)
-
-             else:
-                 if not hasattr(self, '_backend_warned'):
-                     warnings.warn(f"Unsupported backend '{backend}' for GPU memory monitoring.", RuntimeWarning)
-                     self._backend_warned = True
-                 return None
+                     info = tf.config.experimental.get_memory_info(gpu.name)
+                     total += info.get("current", 0)
+                 return total / (1024 ** 2)
+
+             if backend == "torch":
+                 import torch
+
+                 if not torch.cuda.is_available():
+                     return None
+                 total = sum(
+                     torch.cuda.memory_allocated(i)
+                     for i in range(torch.cuda.device_count())
+                 )
+                 return total / (1024 ** 2)
+
+             if backend == "jax":
+                 import jax
+
+                 devs = [d for d in jax.devices() if d.platform == "gpu"]
+                 if not devs:
+                     return None
+                 total = 0
+                 for d in devs:
+                     stats = getattr(d, "memory_stats", lambda: {})()
+                     total += stats.get("bytes_in_use", stats.get("allocated_bytes", 0))
+                 return total / (1024 ** 2)
+
+             if not hasattr(self, "_warned_backend"):
+                 warnings.warn(
+                     f"Backend '{backend}' not supported for GPU memory.",
+                     RuntimeWarning,
+                 )
+                 self._warned_backend = True
+             return None

        except ImportError as e:
-             # Backend library might not be installed
-             if not hasattr(self, f'_{backend}_import_warned'):
-                 warnings.warn(f"MemoryUsageCallback: Could not import library for backend '{backend}': {e}. "
-                               f"GPU monitoring disabled for this backend.", RuntimeWarning)
-                 setattr(self, f'_{backend}_import_warned', True)
-             return None
+             warnings.warn(
+                 f"Could not import backend lib ({e}); GPU disabled.",
+                 RuntimeWarning,
+             )
+             return None
        except Exception as e:
-             # Catch other potential errors during memory retrieval
-             if not hasattr(self, f'_{backend}_error_warned'):
-                 warnings.warn(f"MemoryUsageCallback: Error retrieving GPU memory info for backend '{backend}': {e}", RuntimeWarning)
-                 setattr(self, f'_{backend}_error_warned', True)
+             warnings.warn(f"Error retrieving GPU memory ({e}).", RuntimeWarning)
            return None

-         return gpu_mem_mb
+     def _log(self, label, step):
+         cpu = self._get_cpu_memory()
+         gpu = self._get_gpu_memory()
+         msg = f"{label} - CPU Memory: {cpu:.2f} MB"
+         if gpu is not None:
+             msg += f"; GPU Memory: {gpu:.2f} MB"
+         print(msg)
+         if self.tb_writer:
+             import tensorflow as tf
+
+             with self.tb_writer.as_default(step=int(step)):
+                 tf.summary.scalar("Memory/CPU_MB", cpu)
+                 if gpu is not None:
+                     tf.summary.scalar("Memory/GPU_MB", gpu)
+             self.tb_writer.flush()

+     def on_train_begin(self, logs=None):
+         self._batch_count = 0

-     def _log_memory(self, label, step, cpu_mem, gpu_mem):
-         """Log memory metrics to stdout and potentially TensorBoard."""
-         message = f"{label} - CPU Memory: {cpu_mem:.2f} MB"
-         if gpu_mem is not None:
-             message += f"; GPU Memory: {gpu_mem:.2f} MB"
-         print(message)  # Log to stdout
+     def on_epoch_begin(self, epoch, logs=None):
+         self._log(f"Epoch {epoch} start", epoch)

-         # Log to TensorBoard if writer is configured
+     def on_epoch_end(self, epoch, logs=None):
+         self._log(f"Epoch {epoch} end", epoch + 1)
+
+     def on_batch_end(self, batch, logs=None):
+         if self.log_every_batch:
+             self._log(f"Batch {self._batch_count} end", self._batch_count)
+         self._batch_count += 1
+
+     def on_train_end(self, logs=None):
        if self.tb_writer:
-             try:
-                 with self.tb_writer.as_default(step=int(step)):
-                     tf.summary.scalar("Memory/CPU_MB", cpu_mem)
-                     if gpu_mem is not None:
-                         tf.summary.scalar("Memory/GPU_MB", gpu_mem)
-                 self.tb_writer.flush()
-             except Exception as e:
-                 # Catch potential errors during logging (e.g., writer closed unexpectedly)
-                 if not hasattr(self, '_tb_log_error_warned'):
-                     warnings.warn(f"MemoryUsageCallback: Error writing to TensorBoard: {e}", RuntimeWarning)
-                     self._tb_log_error_warned = True
-                 # Optionally disable writer if logging fails persistently
-                 # self.tb_writer = None
+             self.tb_writer.close()
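A minimal sketch of how the callback introduced by this change might be exercised end to end. It assumes `psutil` is installed, that the public `keras.callbacks.MemoryUsageCallback` export from this diff is available, and that TensorFlow is installed so the optional TensorBoard writer can be created; the model, data, and `./logs/memory` path are illustrative only, not part of the commit.

```python
# Illustrative only: tiny model to exercise MemoryUsageCallback.
# Assumes `pip install psutil`; TensorBoard logging additionally needs TensorFlow.
import numpy as np

import keras
from keras.callbacks import MemoryUsageCallback

model = keras.Sequential(
    [
        keras.layers.Input(shape=(32,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(1),
    ]
)
model.compile(optimizer="adam", loss="mse")

x = np.random.random((256, 32)).astype("float32")
y = np.random.random((256, 1)).astype("float32")

callback = MemoryUsageCallback(
    monitor_gpu=False,  # set True when a supported GPU backend is available
    log_every_batch=False,  # per-epoch logging only
    tensorboard_log_dir="./logs/memory",  # hypothetical path; None disables TensorBoard
)
model.fit(x, y, epochs=2, batch_size=32, callbacks=[callback])
# Each epoch prints lines such as "Epoch 0 start - CPU Memory: 512.34 MB".
```

When `tensorboard_log_dir` is set, running `tensorboard --logdir ./logs/memory` should surface the `Memory/CPU_MB` (and, where available, `Memory/GPU_MB`) scalars written by `_log`.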