@@ -19,11 +19,11 @@ class MemoryUsageCallback(Callback):
19
19
This callback measures:
20
20
21
21
- **CPU**: via psutil.Process().memory_info().rss
22
- - **GPU/TPU**: via backend‐ specific APIs
23
- (TensorFlow, PyTorch, JAX, OpenVINO)
22
+ - **GPU/TPU**: via backend- specific APIs
23
+ (TensorFlow, PyTorch, JAX, OpenVINO)
24
24
25
25
Logs are printed to stdout at the start/end of each epoch and,
26
- if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
26
+ if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
27
27
is provided, scalars are also written via `tf.summary` (TensorBoard).
28
28
29
29
Args:
@@ -36,7 +36,6 @@ class MemoryUsageCallback(Callback):
36
36
ImportError: If `psutil` is not installed (required for CPU logging).
37
37
38
38
Example:
39
-
40
39
```python
41
40
from keras.callbacks import MemoryUsageCallback
42
41
# ...
@@ -46,7 +45,7 @@ class MemoryUsageCallback(Callback):
46
45
tensorboard_log_dir="./logs/memory"
47
46
)
48
47
model.fit(X, y, callbacks=[cb])
49
- ```
48
+ ```
50
49
"""
51
50
52
51
def __init__ (
@@ -67,7 +66,6 @@ def __init__(
67
66
if tensorboard_log_dir :
68
67
try :
69
68
import tensorflow as tf
70
-
71
69
logdir = os .path .expanduser (tensorboard_log_dir )
72
70
self ._writer = tf .summary .create_file_writer (logdir )
73
71
print (f"MemoryUsageCallback: TensorBoard logs → { logdir } " )
@@ -89,9 +87,7 @@ def on_epoch_end(self, epoch, logs=None):
89
87
90
88
def on_batch_end (self , batch , logs = None ):
91
89
if self .log_every_batch :
92
- self ._log_step (
93
- f"Batch { self ._step_counter } end" , self ._step_counter
94
- )
90
+ self ._log_step (f"Batch { self ._step_counter } end" , self ._step_counter )
95
91
self ._step_counter += 1
96
92
97
93
def on_train_end (self , logs = None ):
@@ -110,65 +106,66 @@ def _log_step(self, label, step):
110
106
msg = f"{ label } - CPU Memory: { cpu_mb :.2f} MB"
111
107
if gpu_mb is not None :
112
108
msg += f"; GPU Memory: { gpu_mb :.2f} MB"
113
- print (msg )
109
+ # newline + flush ensures clean, immediate output
110
+ print ("\n " + msg , flush = True )
114
111
115
112
if self ._writer :
116
113
import tensorflow as tf # noqa: E501
117
-
118
114
with self ._writer .as_default (step = int (step )):
119
115
tf .summary .scalar ("Memory/CPU_MB" , cpu_mb )
120
116
if gpu_mb is not None :
121
117
tf .summary .scalar ("Memory/GPU_MB" , gpu_mb )
122
118
self ._writer .flush ()
123
119
124
120
def _get_cpu_memory (self ):
125
- return self ._proc .memory_info ().rss / (1024 ** 2 )
121
+ return self ._proc .memory_info ().rss / (1024 ** 2 )
126
122
127
123
def _get_gpu_memory (self ):
128
124
backend_name = K .backend ()
129
125
try :
130
126
if backend_name == "tensorflow" :
131
127
import tensorflow as tf
132
-
133
128
gpus = tf .config .list_physical_devices ("GPU" )
134
129
if not gpus :
135
130
return None
136
131
total = sum (
137
132
tf .config .experimental .get_memory_info (g .name )["current" ]
138
133
for g in gpus
139
134
)
140
- return total / (1024 ** 2 )
135
+ return total / (1024 ** 2 )
136
+
141
137
elif backend_name == "torch" :
142
138
import torch
143
-
144
139
if not torch .cuda .is_available ():
145
140
return None
146
141
total = sum (
147
142
torch .cuda .memory_allocated (i )
148
143
for i in range (torch .cuda .device_count ())
149
144
)
150
- return total / (1024 ** 2 )
145
+ return total / (1024 ** 2 )
146
+
151
147
elif backend_name == "jax" :
152
148
import jax
153
-
154
149
devs = [d for d in jax .devices () if d .platform .upper () == "GPU" ]
155
150
if not devs :
156
151
return None
157
152
total = 0
158
153
for d in devs :
159
154
stats = getattr (d , "memory_stats" , lambda : {})()
160
155
total += stats .get ("bytes_in_use" , 0 )
161
- return total / (1024 ** 2 )
162
- else :
163
- # OpenVINO and others fall back to unsupported
156
+ return total / (1024 ** 2 )
164
157
158
+ else :
159
+ # OpenVINO and other unknown backends: warn once
165
160
if not hasattr (self , "_warn_backend" ):
166
161
warnings .warn (
167
- f"MemoryUsageCallback: unsupported backend '{ backend_name } '" ,
162
+ "MemoryUsageCallback: unsupported backend "
163
+ f"'{ backend_name } '" ,
168
164
RuntimeWarning ,
169
165
)
170
166
self ._warn_backend = True
171
167
return None
168
+
172
169
except ImportError as imp_err :
173
170
if not hasattr (self , "_warn_import" ):
174
171
warnings .warn (
@@ -177,6 +174,7 @@ def _get_gpu_memory(self):
177
174
)
178
175
self ._warn_import = True
179
176
return None
177
+
180
178
except Exception as exc :
181
179
if not hasattr (self , "_warn_exc" ):
182
180
warnings .warn (
0 commit comments