
Commit c4c0e5e

DimiChatzipavlis authored and committed
Format the code (2)
1 parent 728b770 commit c4c0e5e

File tree

1 file changed: +13 −6 lines changed

keras/src/callbacks/memory_usage_callback.py

Lines changed: 13 additions & 6 deletions
@@ -45,7 +45,7 @@ class MemoryUsageCallback(Callback):
         tensorboard_log_dir="./logs/memory"
     )
     model.fit(X, y, callbacks=[cb])
-    ```
+    ```
     """

     def __init__(
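
For context, the docstring example in the hunk above truncates the constructor call. A minimal end-to-end sketch, assuming the constructor accepts the `log_every_batch` and `tensorboard_log_dir` arguments that appear elsewhere in this diff (the toy model, data, and import path are hypothetical):

```python
import numpy as np
import keras

# Import path guessed from this file's location; the public export may differ.
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback

# Hypothetical toy data and model, just to drive the callback.
X = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")
model = keras.Sequential(
    [keras.layers.Dense(16, activation="relu"), keras.layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse")

cb = MemoryUsageCallback(
    log_every_batch=False,
    tensorboard_log_dir="./logs/memory",
)
model.fit(X, y, epochs=2, batch_size=32, callbacks=[cb])
```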
@@ -66,6 +66,7 @@ def __init__(
         if tensorboard_log_dir:
             try:
                 import tensorflow as tf
+
                 logdir = os.path.expanduser(tensorboard_log_dir)
                 self._writer = tf.summary.create_file_writer(logdir)
                 print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
@@ -87,7 +88,9 @@ def on_epoch_end(self, epoch, logs=None):

     def on_batch_end(self, batch, logs=None):
         if self.log_every_batch:
-            self._log_step(f"Batch {self._step_counter} end", self._step_counter)
+            self._log_step(
+                f"Batch {self._step_counter} end", self._step_counter
+            )
         self._step_counter += 1

     def on_train_end(self, logs=None):
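
The wrapped call above logs against `self._step_counter` rather than `batch`: the `batch` index restarts at 0 every epoch, while the counter keeps growing, giving `_log_step` a monotonically increasing step for the TensorBoard writer. A minimal sketch of that counter pattern:

```python
from keras.callbacks import Callback


class StepCounterDemo(Callback):
    """Sketch: one global step across all epochs, unlike the per-epoch batch index."""

    def __init__(self):
        super().__init__()
        self._step_counter = 0

    def on_batch_end(self, batch, logs=None):
        # `batch` resets each epoch; `self._step_counter` never does.
        print(f"batch={batch} global_step={self._step_counter}")
        self._step_counter += 1
```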
@@ -111,49 +114,53 @@ def _log_step(self, label, step):

         if self._writer:
             import tensorflow as tf  # noqa: E501
+
             with self._writer.as_default(step=int(step)):
                 tf.summary.scalar("Memory/CPU_MB", cpu_mb)
                 if gpu_mb is not None:
                     tf.summary.scalar("Memory/GPU_MB", gpu_mb)
             self._writer.flush()

     def _get_cpu_memory(self):
-        return self._proc.memory_info().rss / (1024 ** 2)
+        return self._proc.memory_info().rss / (1024**2)

     def _get_gpu_memory(self):
         backend_name = K.backend()
         try:
             if backend_name == "tensorflow":
                 import tensorflow as tf
+
                 gpus = tf.config.list_physical_devices("GPU")
                 if not gpus:
                     return None
                 total = sum(
                     tf.config.experimental.get_memory_info(g.name)["current"]
                     for g in gpus
                 )
-                return total / (1024 ** 2)
+                return total / (1024**2)

             elif backend_name == "torch":
                 import torch
+
                 if not torch.cuda.is_available():
                     return None
                 total = sum(
                     torch.cuda.memory_allocated(i)
                     for i in range(torch.cuda.device_count())
                 )
-                return total / (1024 ** 2)
+                return total / (1024**2)

             elif backend_name == "jax":
                 import jax
+
                 devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
                 if not devs:
                     return None
                 total = 0
                 for d in devs:
                     stats = getattr(d, "memory_stats", lambda: {})()
                     total += stats.get("bytes_in_use", 0)
-                return total / (1024 ** 2)
+                return total / (1024**2)

             else:
                 # OpenVINO and other unknown backends: warn once
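
Most of this hunk is the `1024 ** 2` → `1024**2` change: Black omits the spaces around `**` when both operands are simple, so the formatter normalizes every bytes-to-megabytes conversion the same way. For reference, the CPU-side measurement boils down to the following, assuming `self._proc` in the callback is a `psutil.Process` handle for the current interpreter:

```python
import psutil

proc = psutil.Process()  # current process
rss_mb = proc.memory_info().rss / (1024**2)  # resident set size, bytes -> MB
print(f"CPU memory: {rss_mb:.1f} MB")
```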
