Commit 8cf740d

add fbresearch_logger.py (#3215)

Authored by leej3 and vfdev-5

* add fbresearch_logger.py
  Add FBResearchLogger class from unmerged branch object-detection-example
  Add minimal docs and tests
* add some mypy fixes
* fix docs bug
* Update ignite/handlers/fbresearch_logger.py
  Co-authored-by: vfdev <vfdev.5@gmail.com>
* fix type error
* remove types from docstrings

Co-authored-by: vfdev <vfdev.5@gmail.com>

1 parent 42c4d29 commit 8cf740d

File tree

3 files changed: +244 additions, 0 deletions


docs/source/handlers.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ Loggers

     visdom_logger
     wandb_logger
+    fbresearch_logger

 .. seealso::

ignite/handlers/fbresearch_logger.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@

"""FBResearch logger and its helper handlers."""

import datetime
from typing import Any, Optional

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Timer


MB = 1024.0 * 1024.0


class FBResearchLogger:
    """Logs training and validation metrics for research purposes.

    This logger is designed to attach to an Ignite Engine and log various metrics
    and system stats at configurable intervals, including learning rates, iteration
    times, and GPU memory usage.

    Args:
        logger: The logger to use for output.
        delimiter: The delimiter to use between metrics in the log output.
        show_output: Flag to enable logging of the output from the engine's process function.

    Examples:
        .. code-block:: python

            import logging

            from ignite.handlers.fbresearch_logger import FBResearchLogger

            logger = FBResearchLogger(logger=logging.Logger(__name__), show_output=True)
            logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
    """

    def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False):
        self.delimiter = delimiter
        self.logger: Any = logger
        self.iter_timer: Timer = Timer(average=True)
        self.data_timer: Timer = Timer(average=True)
        self.show_output: bool = show_output

    def attach(
        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
    ) -> None:
        """Attaches all the logging handlers to the given engine.

        Args:
            engine: The engine to attach the logging handlers to.
            name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
            every: Frequency of iterations to log information. Logs are generated every ``every`` iterations.
            optimizer: The optimizer used during training to log current learning rates.
        """
        engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
        engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)

        # The iteration timer measures the average time spent in the process function.
        self.iter_timer.reset()
        self.iter_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )
        # The data timer measures the average time spent fetching each batch.
        self.data_timer.reset()
        self.data_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.GET_BATCH_STARTED,
            pause=Events.GET_BATCH_COMPLETED,
            step=Events.GET_BATCH_COMPLETED,
        )

    def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = None) -> None:
        """Logs the training progress at regular intervals.

        Args:
            engine: The training engine.
            optimizer: The optimizer used for training. Defaults to None.
        """
        assert engine.state.epoch_length is not None
        cuda_max_mem = ""
        if torch.cuda.is_available():
            cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

        current_iter = engine.state.iteration % (engine.state.epoch_length + 1)
        iter_avg_time = self.iter_timer.value()

        eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)

        outputs = []
        if self.show_output and engine.state.output is not None:
            output = engine.state.output
            if isinstance(output, dict):
                outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
            else:
                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore

        lrs = ""
        if optimizer is not None:
            if len(optimizer.param_groups) == 1:
                lrs += f"lr: {optimizer.param_groups[0]['lr']:.5f}"
            else:
                for i, g in enumerate(optimizer.param_groups):
                    lrs += f"lr [g{i}]: {g['lr']:.5f}"

        msg = self.delimiter.join(
            [
                f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
                f"[{current_iter}/{engine.state.epoch_length}]:",
                f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}",
                f"{lrs}",
            ]
            + outputs
            + [
                f"Iter time: {iter_avg_time:.4f} s",
                f"Data prep time: {self.data_timer.value():.4f} s",
                cuda_max_mem,
            ]
        )
        self.logger.info(msg)

    def log_epoch_started(self, engine: Engine, name: str) -> None:
        """Logs the start of an epoch.

        Args:
            engine: The engine object.
            name: The name of the engine, included in the log message.
        """
        msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
        self.logger.info(msg)

    def log_epoch_completed(self, engine: Engine, name: str) -> None:
        """Logs the completion of an epoch.

        Args:
            engine: The engine object that triggered the event.
            name: The name of the engine, included in the log message.
        """
        epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
        epoch_info = (
            f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
            if engine.state.max_epochs > 1  # type: ignore
            else ""
        )
        msg = self.delimiter.join(
            [
                f"{name}: {epoch_info}",
                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",  # type: ignore
                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",  # type: ignore
            ]
        )
        self.logger.info(msg)

    def log_completed(self, engine: Engine, name: str) -> None:
        """Logs the completion of a run.

        Args:
            engine: The engine object representing the training/validation loop.
            name: The name of the run.
        """
        if engine.state.max_epochs and engine.state.max_epochs > 1:
            total_time = engine.state.times[Events.COMPLETED.name]
            assert total_time is not None
            msg = self.delimiter.join(
                [
                    f"{name}: run completed",
                    f"Total time: {datetime.timedelta(seconds=int(total_time))}",
                ]
            )
            self.logger.info(msg)
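
The docstring example above assumes an existing `trainer` and `my_optimizer`; for context, here is a minimal self-contained sketch (not part of the commit) of the handler in use. The toy model, `train_step`, and synthetic data are invented for illustration; only `Engine` and `FBResearchLogger` come from the code above.

# Minimal usage sketch (illustration only, not part of this commit).
import logging

import torch
from ignite.engine import Engine
from ignite.handlers.fbresearch_logger import FBResearchLogger

logging.basicConfig(level=logging.INFO)

# Toy regression model and optimizer, invented for this example.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)


def train_step(engine, batch):
    x, y = batch
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    # The returned dict becomes engine.state.output; with show_output=True the
    # logger renders it as "loss: 0.1234" in each progress line.
    return {"loss": loss.item()}


trainer = Engine(train_step)
fb_logger = FBResearchLogger(logger=logging.getLogger(__name__), show_output=True)
fb_logger.attach(trainer, name="Train", every=5, optimizer=opt)

# 20 synthetic batches per epoch; a progress line is logged every 5 iterations.
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(20)]
trainer.run(data, max_epochs=2)
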
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
1+
import logging
2+
import re
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from ignite.engine import Engine, Events
8+
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
9+
10+
11+
@pytest.fixture
12+
def mock_engine():
13+
engine = Engine(lambda e, b: None)
14+
engine.state.epoch = 1
15+
engine.state.max_epochs = 10
16+
engine.state.epoch_length = 100
17+
engine.state.iteration = 50
18+
return engine
19+
20+
21+
@pytest.fixture
22+
def mock_logger():
23+
return MagicMock(spec=logging.Logger)
24+
25+
26+
@pytest.fixture
27+
def fb_research_logger(mock_logger):
28+
yield FBResearchLogger(logger=mock_logger, show_output=True)
29+
30+
31+
def test_fbresearch_logger_initialization(mock_logger):
32+
logger = FBResearchLogger(logger=mock_logger, show_output=True)
33+
assert logger.logger == mock_logger
34+
assert logger.show_output is True
35+
36+
37+
def test_fbresearch_logger_attach(mock_engine, mock_logger):
38+
logger = FBResearchLogger(logger=mock_logger, show_output=True)
39+
logger.attach(mock_engine, name="Test", every=1)
40+
assert mock_engine.has_event_handler(logger.log_every, Events.ITERATION_COMPLETED)
41+
42+
43+
@pytest.mark.parametrize(
44+
"output,expected_pattern",
45+
[
46+
({"loss": 0.456, "accuracy": 0.789}, r"loss. *0.456.*accuracy. *0.789"),
47+
((0.456, 0.789), r"0.456.*0.789"),
48+
([0.456, 0.789], r"0.456.*0.789"),
49+
],
50+
)
51+
def test_output_formatting(mock_engine, fb_research_logger, output, expected_pattern):
52+
# Ensure the logger correctly formats and logs the output for each type
53+
mock_engine.state.output = output
54+
fb_research_logger.attach(mock_engine, name="Test", every=1)
55+
mock_engine.fire_event(Events.ITERATION_COMPLETED)
56+
57+
actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
58+
assert re.search(expected_pattern, actual_output)
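
For intuition about the patterns above: with the dict output from the first parametrized case and `show_output=True`, `log_every` emits a line shaped roughly as in the sketch below. The sample string is illustrative, not captured output; ETA and timing fields vary per run, which is why the regex is kept loose.

import re

# Plausible shape of a logged line for output {"loss": 0.456, "accuracy": 0.789},
# given the fixture values epoch=1, max_epochs=10, iteration=50, epoch_length=100.
sample = "Epoch [1/10] [50/100]: ETA: 0:00:05  loss: 0.4560 accuracy: 0.7890 Iter time: 0.1000 s Data prep time: 0.0100 s"
assert re.search(r"loss. *0.456.*accuracy. *0.789", sample)
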
