@@ -45,7 +45,7 @@ class MemoryUsageCallback(Callback):
         tensorboard_log_dir="./logs/memory"
     )
     model.fit(X, y, callbacks=[cb])
-    ```
+    ```
     """

     def __init__(
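For readers skimming the diff, here is a self-contained version of the docstring example above. It is a sketch only: it assumes `MemoryUsageCallback` is importable from the module this patch edits (the import path below is a guess) and that the constructor accepts the `log_every_batch` and `tensorboard_log_dir` arguments seen elsewhere in this diff; the synthetic data and model are purely illustrative.

```python
import numpy as np
from keras import layers, models
from memory_usage_callback import MemoryUsageCallback  # hypothetical import path

# Tiny synthetic regression task, just enough to drive the callback.
X = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model = models.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

cb = MemoryUsageCallback(
    log_every_batch=True,  # per-batch logging, see on_batch_end below
    tensorboard_log_dir="./logs/memory",
)
model.fit(X, y, epochs=2, batch_size=32, callbacks=[cb])
```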
@@ -66,6 +66,7 @@ def __init__(
         if tensorboard_log_dir:
             try:
                 import tensorflow as tf
+
                 logdir = os.path.expanduser(tensorboard_log_dir)
                 self._writer = tf.summary.create_file_writer(logdir)
                 print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
@@ -87,7 +88,9 @@ def on_epoch_end(self, epoch, logs=None):

     def on_batch_end(self, batch, logs=None):
         if self.log_every_batch:
-            self._log_step(f"Batch {self._step_counter} end", self._step_counter)
+            self._log_step(
+                f"Batch {self._step_counter} end", self._step_counter
+            )
         self._step_counter += 1

     def on_train_end(self, logs=None):
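`on_batch_end` above funnels into `_log_step` (next hunk), which follows the standard TF2 summary-writing pattern: pin a step on the writer, emit scalars, flush. A standalone sketch with made-up values:

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("./logs/memory")
with writer.as_default(step=0):  # the step becomes the TensorBoard x-axis value
    tf.summary.scalar("Memory/CPU_MB", 512.0)   # illustrative value
    tf.summary.scalar("Memory/GPU_MB", 1024.0)  # illustrative value
writer.flush()
```

Scalars written this way show up under the `Memory/` tags once you run `tensorboard --logdir ./logs/memory`.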
@@ -111,49 +114,53 @@ def _log_step(self, label, step):

         if self._writer:
             import tensorflow as tf  # noqa: E501
+
             with self._writer.as_default(step=int(step)):
                 tf.summary.scalar("Memory/CPU_MB", cpu_mb)
                 if gpu_mb is not None:
                     tf.summary.scalar("Memory/GPU_MB", gpu_mb)
             self._writer.flush()

     def _get_cpu_memory(self):
-        return self._proc.memory_info().rss / (1024 ** 2)
+        return self._proc.memory_info().rss / (1024**2)

     def _get_gpu_memory(self):
         backend_name = K.backend()
         try:
             if backend_name == "tensorflow":
                 import tensorflow as tf
+
                 gpus = tf.config.list_physical_devices("GPU")
                 if not gpus:
                     return None
                 total = sum(
                     tf.config.experimental.get_memory_info(g.name)["current"]
                     for g in gpus
                 )
-                return total / (1024 ** 2)
+                return total / (1024**2)

             elif backend_name == "torch":
                 import torch
+
                 if not torch.cuda.is_available():
                     return None
                 total = sum(
                     torch.cuda.memory_allocated(i)
                     for i in range(torch.cuda.device_count())
                 )
-                return total / (1024 ** 2)
+                return total / (1024**2)

             elif backend_name == "jax":
                 import jax
+
                 devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
                 if not devs:
                     return None
                 total = 0
                 for d in devs:
                     stats = getattr(d, "memory_stats", lambda: {})()
                     total += stats.get("bytes_in_use", 0)
-                return total / (1024 ** 2)
+                return total / (1024**2)

             else:
                 # OpenVINO and other unknown backends: warn once
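The probes in `_get_cpu_memory` and `_get_gpu_memory` are each a thin wrapper over a public API. A hedged, standalone sketch of the psutil and PyTorch branches (the function names are mine; whether a GPU is present, and hence the numbers you get, is platform dependent):

```python
import psutil


def cpu_mb():
    # Resident set size (RSS) of the current process, in MiB.
    return psutil.Process().memory_info().rss / (1024**2)


def torch_gpu_mb():
    import torch

    if not torch.cuda.is_available():
        return None
    # memory_allocated(i) reports bytes currently held by tensors on device i.
    return sum(
        torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())
    ) / (1024**2)


print(f"CPU: {cpu_mb():.1f} MiB, GPU: {torch_gpu_mb()} MiB")
```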