Skip to content

Commit fe59889

Browse files
committed
Add prebatch callback.
1 parent 8390d52 commit fe59889

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/fflib/utils/iff_suite.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242

4343
# Members that get reset even when loading pretrained networks
4444
self.pre_epoch_callback: Callable[[IFF, int], Any] | None = None
45+
self.pre_batch_callback: Callable[[IFF, int, int], Any] | None = None
4546

4647
logger.info("Created FFSuite.")
4748

@@ -68,6 +69,9 @@ def callback(net: IFF, e: int):
6869

6970
self.pre_epoch_callback = callback
7071

72+
def set_pre_batch_callback(self, callback: Callable[[IFF, int, int], Any]) -> None:
73+
self.pre_batch_callback = callback
74+
7175
def run_test_epoch(self, loader: DataLoader[Any]) -> float:
7276
self.net.eval()
7377
test_correct: int = 0
@@ -96,12 +100,15 @@ def run_train_epoch(self, validate: bool = True) -> None:
96100
if self.pre_epoch_callback is not None:
97101
self.pre_epoch_callback(self.net, self.current_epoch)
98102

99-
for b in tqdm(loaders["train"]):
103+
for i, b in tqdm(enumerate(loaders["train"]), total=len(loaders["train"])):
100104
batch: Tuple[torch.Tensor, torch.Tensor] = b
101105
x, y = batch
102106
if self.device is not None:
103107
x, y = x.to(self.device), y.to(self.device)
104108

109+
if self.pre_batch_callback is not None:
110+
self.pre_batch_callback(self.net, self.current_epoch, i)
111+
105112
self._train(x, y)
106113

107114
# Validation phase

0 commit comments

Comments
 (0)