 from keras.src.api_export import keras_export
 from keras.src.callbacks.callback import Callback

-# Attempt to import psutil for memory monitoring
+# Attempt to import psutil for CPU memory monitoring
 try:
     import psutil
 except ImportError:
     psutil = None


 @keras_export("keras.callbacks.MemoryUsageCallback")
 class MemoryUsageCallback(Callback):
-    """Monitors and logs memory usage (CPU + optional GPU/TPU) during training.
+    """Monitors and logs memory usage
+    (CPU + optional GPU/TPU/OpenVINO) during training.

     This callback measures:

     - **CPU**: via psutil.Process().memory_info().rss
-    - **GPU/TPU**: via backend-specific APIs
-      (TensorFlow, PyTorch, JAX, OpenVINO)
+    - **GPU/TPU**: via backend-specific APIs
+      (TensorFlow, PyTorch, JAX)

     Logs are printed to stdout at the start/end of each epoch and,
-    if `log_every_batch=True`, after every batch. If `tensorboard_log_dir`
-    is provided, scalars are also written via `tf.summary` (TensorBoard).
+    if `log_every_batch=True`, after every batch.
+    If `tensorboard_log_dir` is provided, scalars are also written
+    via `tf.summary` (TensorBoard).

     Args:
         monitor_gpu (bool): If True, attempt to measure accelerator memory.
@@ -34,22 +36,13 @@ class MemoryUsageCallback(Callback):

     Raises:
         ImportError: If `psutil` is not installed (required for CPU logging).
-
-    Example:
-    ```python
-    from keras.callbacks import MemoryUsageCallback
-    # ...
-    cb = MemoryUsageCallback(
-        monitor_gpu=True,
-        log_every_batch=False,
-        tensorboard_log_dir="./logs/memory"
-    )
-    model.fit(X, y, callbacks=[cb])
-    ```
     """

     def __init__(
-        self, monitor_gpu=True, log_every_batch=False, tensorboard_log_dir=None
+        self,
+        monitor_gpu=True,
+        log_every_batch=False,
+        tensorboard_log_dir=None,
     ):
         super().__init__()
         if psutil is None:
@@ -109,8 +102,7 @@ def _log_step(self, label, step):
         msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
         if gpu_mb is not None:
             msg += f"; GPU Memory: {gpu_mb:.2f} MB"
-        # newline + flush ensures clean, immediate output
-        print("\n" + msg, flush=True)
+        print(msg)

         if self._writer:
             import tensorflow as tf  # noqa: E501
@@ -119,7 +111,7 @@ def _log_step(self, label, step):
                 tf.summary.scalar("Memory/CPU_MB", cpu_mb)
                 if gpu_mb is not None:
                     tf.summary.scalar("Memory/GPU_MB", gpu_mb)
-            self._writer.flush()
+            # flush happens inside writer

     def _get_cpu_memory(self):
         return self._proc.memory_info().rss / (1024**2)
@@ -162,12 +154,21 @@ def _get_gpu_memory(self):
                 total += stats.get("bytes_in_use", 0)
             return total / (1024**2)

+        elif backend_name == "openvino":
+            # OpenVINO provides no memory-stats API:
+            if not hasattr(self, "_warn_openvino"):
+                warnings.warn(
+                    "OpenVINO does not expose memory stats; "
+                    "GPU monitoring disabled.",
+                    RuntimeWarning,
+                )
+                self._warn_openvino = True
+            return None
+
         else:
-            # OpenVINO and other unknown backends: warn once
             if not hasattr(self, "_warn_backend"):
                 warnings.warn(
-                    "MemoryUsageCallback: unsupported backend "
-                    f"'{backend_name}'",
+                    f"MemoryUsageCallback: unknown backend '{backend_name}'",
                     RuntimeWarning,
                 )
                 self._warn_backend = True
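
The diff above removes the inline docstring example, so for reference here is a minimal usage sketch of the callback as it appears in this PR. It assumes the export path `keras.callbacks.MemoryUsageCallback` from the `@keras_export` decorator, that `psutil` is installed, and that TensorFlow is available when `tensorboard_log_dir` is set; the toy dataset and model below are placeholders, not part of the PR.

```python
import numpy as np
import keras

# Export path taken from the @keras_export decorator in the diff;
# only available once this PR is merged.
from keras.callbacks import MemoryUsageCallback

# Small synthetic dataset and model, purely for illustration.
X = np.random.rand(256, 10).astype("float32")
y = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = keras.Sequential(
    [
        keras.Input(shape=(10,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Print CPU (and, where supported, GPU/TPU) memory at epoch boundaries
# and mirror the readings to TensorBoard scalars.
memory_cb = MemoryUsageCallback(
    monitor_gpu=True,
    log_every_batch=False,
    tensorboard_log_dir="./logs/memory",
)
model.fit(X, y, epochs=2, batch_size=32, callbacks=[memory_cb])
```

On the flush change in `_log_step`: `tf.summary` file writers buffer events and flush on a timer and on `flush()`/`close()`, so dropping the per-step `flush()` reduces I/O at the cost of slightly delayed TensorBoard updates.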