Skip to content

Commit 1ae7659

Browse files
DimiChatzipavlisDimiChatzipavlis
authored andcommitted
Reformatted code
1 parent 5af4a44 commit 1ae7659

File tree

3 files changed

+44
-21
lines changed

3 files changed

+44
-21
lines changed

keras/src/callbacks/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@
1414
from keras.src.callbacks.swap_ema_weights import SwapEMAWeights
1515
from keras.src.callbacks.tensorboard import TensorBoard
1616
from keras.src.callbacks.terminate_on_nan import TerminateOnNaN
17-

keras/src/callbacks/memory_usage_callback.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import os
22
import warnings
3+
4+
from keras.src import backend as K
35
from keras.src.api_export import keras_export
46
from keras.src.callbacks.callback import Callback
5-
from keras.src import backend as K
67

78
# Attempt to import psutil for memory monitoring
89
try:
910
import psutil
1011
except ImportError:
11-
psutil = None
12+
psutil = None
13+
1214

1315
@keras_export("keras.callbacks.MemoryUsageCallback")
1416
class MemoryUsageCallback(Callback):
@@ -63,14 +65,15 @@ def __init__(
6365

6466
if tensorboard_log_dir:
6567
try:
66-
import tensorflow as tf
68+
import tensorflow as tf
6769

6870
logdir = os.path.expanduser(tensorboard_log_dir)
6971
self._writer = tf.summary.create_file_writer(logdir)
7072
print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
7173
except Exception as e:
7274
warnings.warn(
73-
f"Could not initialize TensorBoard writer: {e}", RuntimeWarning
75+
f"Could not initialize TensorBoard writer: {e}",
76+
RuntimeWarning,
7477
)
7578
self._writer = None
7679

@@ -85,7 +88,9 @@ def on_epoch_end(self, epoch, logs=None):
8588

8689
def on_batch_end(self, batch, logs=None):
8790
if self.log_every_batch:
88-
self._log_step(f"Batch {self._step_counter} end", self._step_counter)
91+
self._log_step(
92+
f"Batch {self._step_counter} end", self._step_counter
93+
)
8994
self._step_counter += 1
9095

9196
def on_train_end(self, logs=None):
@@ -122,7 +127,7 @@ def _get_gpu_memory(self):
122127
backend_name = K.backend()
123128
try:
124129
if backend_name == "tensorflow":
125-
import tensorflow as tf
130+
import tensorflow as tf
126131

127132
gpus = tf.config.list_physical_devices("GPU")
128133
if not gpus:
@@ -173,6 +178,8 @@ def _get_gpu_memory(self):
173178
return None
174179
except Exception as exc:
175180
if not hasattr(self, "_warn_exc"):
176-
warnings.warn(f"Error retrieving GPU memory: {exc}", RuntimeWarning)
181+
warnings.warn(
182+
f"Error retrieving GPU memory: {exc}", RuntimeWarning
183+
)
177184
self._warn_exc = True
178185
return None

keras/src/callbacks/memory_usage_callback_test.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
import os
21
import glob
2+
import os
33
import re
44
import sys
55
import tempfile
6-
import pytest
7-
import numpy as np
8-
96
from contextlib import redirect_stdout
10-
from io import StringIO
117
from importlib import reload
12-
from unittest.mock import patch, MagicMock
8+
from io import StringIO
9+
from unittest.mock import MagicMock
10+
from unittest.mock import patch
1311

14-
from keras.src.models import Sequential
12+
import numpy as np
13+
import pytest
14+
15+
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
1516
from keras.src.layers import Dense
17+
from keras.src.models import Sequential
1618
from keras.src.testing import TestCase
17-
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
1819

1920
try:
2021
import psutil
@@ -69,13 +70,19 @@ def test_batch_logging_stdout(self):
6970
with redirect_stdout(buf):
7071
cb = MemoryUsageCallback(monitor_gpu=False, log_every_batch=True)
7172
self.model.fit(
72-
self.x, self.y, epochs=1, batch_size=self.bs, callbacks=[cb], verbose=0
73+
self.x,
74+
self.y,
75+
epochs=1,
76+
batch_size=self.bs,
77+
callbacks=[cb],
78+
verbose=0,
7379
)
7480
lines = buf.getvalue().splitlines()
7581
batch_lines = [l for l in lines if l.startswith("Batch ")]
7682
assert len(batch_lines) == self.steps
7783
assert all(
78-
re.match(r"Batch \d+ end - CPU Memory: [\d\.]+ MB", l) for l in batch_lines
84+
re.match(r"Batch \d+ end - CPU Memory: [\d\.]+ MB", l)
85+
for l in batch_lines
7986
)
8087

8188
@pytest.mark.requires_trainable_backend
@@ -85,9 +92,16 @@ def test_tensorboard_writes_files(self):
8592
logdir = os.path.join(tmp.name, "tb")
8693
buf = StringIO()
8794
with redirect_stdout(buf):
88-
cb = MemoryUsageCallback(monitor_gpu=False, tensorboard_log_dir=logdir)
95+
cb = MemoryUsageCallback(
96+
monitor_gpu=False, tensorboard_log_dir=logdir
97+
)
8998
self.model.fit(
90-
self.x, self.y, epochs=1, batch_size=self.bs, callbacks=[cb], verbose=0
99+
self.x,
100+
self.y,
101+
epochs=1,
102+
batch_size=self.bs,
103+
callbacks=[cb],
104+
verbose=0,
91105
)
92106
files = glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
93107
assert files, "No TensorBoard event files generated"
@@ -118,7 +132,10 @@ def test_torch_backend_gpu_memory(monkeypatch):
118132
fake_torch = MagicMock()
119133
fake_torch.cuda.is_available.return_value = True
120134
fake_torch.cuda.device_count.return_value = 2
121-
fake_torch.cuda.memory_allocated.side_effect = [100 * 1024**2, 150 * 1024**2]
135+
fake_torch.cuda.memory_allocated.side_effect = [
136+
100 * 1024**2,
137+
150 * 1024**2,
138+
]
122139
monkeypatch.setitem(sys.modules, "torch", fake_torch)
123140

124141
cb = MemoryUsageCallback(monitor_gpu=True)

0 commit comments

Comments
 (0)