@@ -42,6 +42,7 @@ def __init__(
4242
4343 # Members that get reset even when loading pretrained networks
4444 self .pre_epoch_callback : Callable [[IFF , int ], Any ] | None = None
45+ self .pre_batch_callback : Callable [[IFF , int , int ], Any ] | None = None
4546
4647 logger .info ("Created FFSuite." )
4748
@@ -68,6 +69,9 @@ def callback(net: IFF, e: int):
6869
6970 self .pre_epoch_callback = callback
7071
72+ def set_pre_batch_callback (self , callback : Callable [[IFF , int , int ], Any ]) -> None :
73+ self .pre_batch_callback = callback
74+
7175 def run_test_epoch (self , loader : DataLoader [Any ]) -> float :
7276 self .net .eval ()
7377 test_correct : int = 0
@@ -96,12 +100,15 @@ def run_train_epoch(self, validate: bool = True) -> None:
96100 if self .pre_epoch_callback is not None :
97101 self .pre_epoch_callback (self .net , self .current_epoch )
98102
99- for b in tqdm (loaders ["train" ]):
103+ for i , b in tqdm (enumerate ( loaders ["train" ]), total = len ( loaders [ "train" ]) ):
100104 batch : Tuple [torch .Tensor , torch .Tensor ] = b
101105 x , y = batch
102106 if self .device is not None :
103107 x , y = x .to (self .device ), y .to (self .device )
104108
109+ if self .pre_batch_callback is not None :
110+ self .pre_batch_callback (self .net , self .current_epoch , i )
111+
105112 self ._train (x , y )
106113
107114 # Validation phase
0 commit comments