5
5
from keras .src .callbacks .callback import Callback
6
6
from keras .src import backend as K
7
7
8
- # Attempt to import psutil for CPU memory
9
-
10
8
try :
11
9
import psutil
12
10
except ImportError :
15
13
16
14
@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
    """Monitor CPU/GPU/TPU/OpenVINO memory during training.

    Tracks:
      - CPU memory via `psutil.Process().memory_info().rss`.
      - GPU memory via backend APIs (TF, Torch, JAX, OpenVINO).
      - Logs to stdout and, optionally, to TensorBoard.

    Args:
        monitor_gpu: Bool. If True, query GPU/accelerator memory.
        log_every_batch: Bool. If True, log after each batch.
        tensorboard_log_dir: str or None. If set, use TF summary writer.

    Raises:
        ImportError: If `psutil` is missing.
    """

    def __init__(
        self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None
    ):
        super().__init__()
        # `psutil` is imported at module level inside a try/except; it is
        # None when unavailable, so fail loudly here rather than at first use.
        if psutil is None:
            raise ImportError(
                "MemoryUsageCallback requires the 'psutil' library."
            )
        self.monitor_gpu = monitor_gpu
        self.log_every_batch = log_every_batch
        self.process = psutil.Process()
        self.tb_writer = None
        self._batches_seen = 0  # global batch counter across epochs
        self._warned_backend = False  # ensure unsupported-backend warning fires once

        if tensorboard_log_dir:
            try:
                import tensorflow as tf

                logdir = os.path.expanduser(tensorboard_log_dir)
                self.tb_writer = tf.summary.create_file_writer(logdir)
            except Exception as e:
                # TensorBoard logging is best-effort; training must not fail
                # because TF is absent or the log dir is unwritable.
                warnings.warn(f"TB init error: {e}", RuntimeWarning)

    def on_train_begin(self, logs=None):
        self._batches_seen = 0

    def on_epoch_begin(self, epoch, logs=None):
        cpu = self._cpu_mem_mb()
        gpu = self._get_gpu_memory()
        self._log("Epoch %d start" % epoch, epoch, cpu, gpu)

    def on_epoch_end(self, epoch, logs=None):
        cpu = self._cpu_mem_mb()
        gpu = self._get_gpu_memory()
        # Step epoch+1 so the "end" sample does not overwrite the "start"
        # sample of the same epoch in TensorBoard.
        self._log("Epoch %d end" % epoch, epoch + 1, cpu, gpu)

    def on_batch_end(self, batch, logs=None):
        if self.log_every_batch:
            cpu = self._cpu_mem_mb()
            gpu = self._get_gpu_memory()
            self._log(
                f"Batch {self._batches_seen} end", self._batches_seen, cpu, gpu
            )
        self._batches_seen += 1

    def on_train_end(self, logs=None):
        if self.tb_writer:
            self.tb_writer.close()

    def _cpu_mem_mb(self):
        """Return resident set size of this process, in MB."""
        return self.process.memory_info().rss / (1024**2)

    def _get_gpu_memory(self):
        """Return total accelerator memory in use (MB), or None.

        Returns None when monitoring is disabled, no device is present,
        or the active backend is unsupported.
        """
        if not self.monitor_gpu:
            return None
        backend = K.backend()
        try:
            if backend == "tensorflow":
                import tensorflow as tf

                gpus = tf.config.list_physical_devices("GPU")
                if not gpus:
                    return None
                total = 0
                for g in gpus:
                    # get_memory_info() expects "GPU:0"-style names, while
                    # list_physical_devices() yields "/physical_device:GPU:0";
                    # strip the prefix before querying.
                    dev_name = g.name.replace("/physical_device:", "")
                    info = tf.config.experimental.get_memory_info(dev_name)
                    total += info.get("current", 0)
                return total / (1024**2)
            if backend == "torch":
                import torch

                if not torch.cuda.is_available():
                    return None
                # NOTE(review): this summation was reconstructed from a
                # truncated diff hunk — confirm against the original.
                total = sum(
                    torch.cuda.memory_allocated(i)
                    for i in range(torch.cuda.device_count())
                )
                return total / (1024**2)
            if backend == "jax":
                import jax

                devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
                if not devs:
                    return None
                total = 0
                for d in devs:
                    # memory_stats() may return None on some platforms.
                    stats = d.memory_stats() or {}
                    total += stats.get("bytes_in_use", 0)
                return total / (1024**2)
            if backend == "openvino":
                try:
                    import openvino as ov

                    core = ov.Core()
                    devices = core.available_devices
                    total = 0
                    for dev in devices:
                        # NOTE(review): "DEVICE_MEMORY_STATISTICS" is not a
                        # documented core property on all OpenVINO devices —
                        # verify; failures fall into the handler below.
                        stats = core.get_property(
                            dev, "DEVICE_MEMORY_STATISTICS"
                        )
                        total += stats.get("deviceUsedBytes", 0)
                    return total / (1024**2)
                except Exception as e:
                    warnings.warn(f"OVINO mem err: {e}", RuntimeWarning)
                    return None
        except ImportError as e:
            warnings.warn(f"Import err for {backend}: {e}", RuntimeWarning)
            return None
        # Warn only once: this method runs every epoch (and possibly every
        # batch), so an unguarded warning would spam the logs.
        if not self._warned_backend:
            warnings.warn(f"Unsupported backend '{backend}'", RuntimeWarning)
            self._warned_backend = True
        return None

    def _log(self, label, step, cpu, gpu):
        """Print one memory sample and mirror it to TensorBoard if enabled."""
        msg = f"{label} - CPU: {cpu:.2f} MB"
        if gpu is not None:
            msg += f"; GPU: {gpu:.2f} MB"
        print(msg)
        if self.tb_writer:
            import tensorflow as tf

            with self.tb_writer.as_default(step=step):
                tf.summary.scalar("Memory/CPU_MB", cpu)
                if gpu is not None:
                    tf.summary.scalar("Memory/GPU_MB", gpu)
            self.tb_writer.flush()
0 commit comments