Skip to content

Commit daddf29

Browse files
Fix formatting errors
1 parent 5f9d975 commit daddf29

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

keras/src/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from keras.src.callbacks.history import History
77
from keras.src.callbacks.lambda_callback import LambdaCallback
88
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
9+
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
910
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
1011
from keras.src.callbacks.progbar_logger import ProgbarLogger
1112
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
1213
from keras.src.callbacks.remote_monitor import RemoteMonitor
1314
from keras.src.callbacks.swap_ema_weights import SwapEMAWeights
1415
from keras.src.callbacks.tensorboard import TensorBoard
1516
from keras.src.callbacks.terminate_on_nan import TerminateOnNaN
16-
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
17+

keras/src/callbacks/memory_usage_callback.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.src import backend as K
77

88
# Attempt to import psutil for CPU memory
9+
910
try:
1011
import psutil
1112
except ImportError:
@@ -56,7 +57,7 @@ def __init__(
5657

5758
if tensorboard_log_dir:
5859
try:
59-
import tensorflow as tf
60+
import tensorflow as tf
6061

6162
logdir = os.path.expanduser(tensorboard_log_dir)
6263
self.tb_writer = tf.summary.create_file_writer(logdir)
@@ -79,7 +80,7 @@ def _get_gpu_memory(self):
7980
backend = K.backend()
8081
try:
8182
if backend == "tensorflow":
82-
import tensorflow as tf
83+
import tensorflow as tf
8384

8485
gpus = tf.config.list_physical_devices("GPU")
8586
if not gpus:
@@ -89,9 +90,8 @@ def _get_gpu_memory(self):
8990
info = tf.config.experimental.get_memory_info(gpu.name)
9091
total += info.get("current", 0)
9192
return total / (1024**2)
92-
9393
if backend == "torch":
94-
import torch
94+
import torch
9595

9696
if not torch.cuda.is_available():
9797
return None
@@ -100,9 +100,8 @@ def _get_gpu_memory(self):
100100
for i in range(torch.cuda.device_count())
101101
)
102102
return total / (1024**2)
103-
104103
if backend == "jax":
105-
import jax
104+
import jax
106105

107106
devs = [d for d in jax.devices() if d.platform == "gpu"]
108107
if not devs:
@@ -112,15 +111,13 @@ def _get_gpu_memory(self):
112111
stats = getattr(d, "memory_stats", lambda: {})()
113112
total += stats.get("bytes_in_use", stats.get("allocated_bytes", 0))
114113
return total / (1024**2)
115-
116114
if not hasattr(self, "_warned_backend"):
117115
warnings.warn(
118116
f"Backend '{backend}' not supported for GPU memory.",
119117
RuntimeWarning,
120118
)
121119
self._warned_backend = True
122120
return None
123-
124121
except ImportError as e:
125122
warnings.warn(
126123
f"Could not import backend lib ({e}); GPU disabled.",
@@ -139,7 +136,7 @@ def _log(self, label, step):
139136
msg += f"; GPU Memory: {gpu:.2f} MB"
140137
print(msg)
141138
if self.tb_writer:
142-
import tensorflow as tf
139+
import tensorflow as tf
143140

144141
with self.tb_writer.as_default(step=int(step)):
145142
tf.summary.scalar("Memory/CPU_MB", cpu)

keras/src/callbacks/memory_usage_callback_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import glob
33
import tempfile
4-
import warnings
54

65
from contextlib import redirect_stdout
76
from io import StringIO
@@ -18,6 +17,7 @@
1817
from keras.src import backend as K
1918

2019
# Skip all tests if psutil is not installed
20+
2121
try:
2222
import psutil
2323
except ImportError:
@@ -48,6 +48,7 @@ def test_epoch_and_batch_stdout(self):
4848
out = StringIO()
4949
with redirect_stdout(out):
5050
# Mock GPU memory for predictability
51+
5152
with patch.object(
5253
MemoryUsageCallback, "_get_gpu_memory", return_value=42.0
5354
):
@@ -62,20 +63,24 @@ def test_epoch_and_batch_stdout(self):
6263
)
6364
log = out.getvalue().splitlines()
6465
# Check epoch logs
66+
6567
for i in range(self.epochs):
6668
assert any(f"Epoch {i} start" in line for line in log)
6769
assert any(f"Epoch {i} end" in line for line in log)
6870
# Check batch logs count
71+
6972
batch_lines = [l for l in log if l.startswith("Batch")]
7073
assert len(batch_lines) == self.total_batches
7174
# Confirm GPU part present
75+
7276
assert any("GPU Memory: 42.00 MB" in l for l in log)
7377

7478
@pytest.mark.requires_trainable_backend
7579
def test_tensorboard_file_creation(self):
7680
with tempfile.TemporaryDirectory() as tmpdir:
7781
tb_dir = os.path.join(tmpdir, "tb")
7882
# Mock CPU/GPU memory
83+
7984
with patch.object(
8085
MemoryUsageCallback, "_get_gpu_memory", return_value=10.0
8186
), patch.object(MemoryUsageCallback, "_get_cpu_memory", return_value=5.0):
@@ -106,12 +111,15 @@ def test_import_error_without_psutil(self):
106111
reload(mod)
107112
_ = mod.MemoryUsageCallback()
108113
# restore
114+
109115
if orig is not None:
110116
sys.modules["psutil"] = orig
111117
reload(mod)
112118

113119

114120
# Backend-specific tests
121+
122+
115123
@pytest.mark.requires_trainable_backend
116124
def test_torch_gpu_memory(monkeypatch):
117125
monkeypatch.setattr(K, "backend", lambda: "torch")

0 commit comments

Comments
 (0)