Skip to content

Commit 0aca27e

Browse files
committed
step based logging.
1 parent cd3ee78 commit 0aca27e

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

train.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@
6464
has_functorch = True
6565
except ImportError as e:
6666
has_functorch = False
67-
67+
#test tensorboard install
68+
try:
69+
from torch.utils.tensorboard import SummaryWriter
70+
has_tensorboard = True
71+
except ImportError as e:
72+
has_tensorboard = False
6873
has_compile = hasattr(torch, 'compile')
6974

7075

@@ -347,8 +352,8 @@
347352
help='use the multi-epochs-loader to save time at the beginning of every epoch')
348353
group.add_argument('--log-wandb', action='store_true', default=False,
349354
help='log training and validation metrics to wandb')
350-
351-
355+
group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
356+
help='log training and validation metrics to TensorBoard')
352357
def _parse_args():
353358
# Do we have a config file to parse?
354359
args_config, remaining = config_parser.parse_known_args()
@@ -726,6 +731,16 @@ def main():
726731
"You've requested to log metrics to wandb but package not found. "
727732
"Metrics not being logged to wandb, try `pip install wandb`")
728733

734+
if utils.is_primary(args) and args.log_tensorboard:
735+
if has_tensorboard:
736+
writer = SummaryWriter(args.log_tensorboard)
737+
else:
738+
_logger.warning(
739+
"You've requested to log metrics to tensorboard but package not found. "
740+
"Metrics not being logged to tensorboard, try `pip install tensorboard`")
741+
742+
743+
729744
# setup learning rate schedule and starting epoch
730745
updates_per_epoch = len(loader_train)
731746
lr_scheduler, num_epochs = create_scheduler_v2(
@@ -809,6 +824,7 @@ def main():
809824
lr=sum(lrs) / len(lrs),
810825
write_header=best_metric is None,
811826
log_wandb=args.log_wandb and has_wandb,
827+
812828
)
813829

814830
if saver is not None:
@@ -903,6 +919,10 @@ def train_one_epoch(
903919

904920
num_updates += 1
905921
batch_time_m.update(time.time() - end)
922+
#write to tensorboard if enabled
923+
if should_log_to_tensorboard(args):
924+
writer.add_scalar('train/loss', losses_m.val, num_updates)
925+
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
906926
if last_batch or batch_idx % args.log_interval == 0:
907927
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
908928
lr = sum(lrl) / len(lrl)
@@ -954,6 +974,10 @@ def train_one_epoch(
954974
return OrderedDict([('loss', losses_m.avg)])
955975

956976

977+
def should_log_to_tensorboard(args):
978+
return args.log_tensorboard and utils.is_primary(args) and has_tensorboard
979+
980+
957981
def validate(
958982
model,
959983
loader,
@@ -1011,6 +1035,10 @@ def validate(
10111035

10121036
batch_time_m.update(time.time() - end)
10131037
end = time.time()
1038+
if should_log_to_tensorboard(args):
1039+
writer.add_scalar('val/loss', losses_m.val, batch_idx)
1040+
writer.add_scalar('val/acc1', top1_m.val, batch_idx)
1041+
writer.add_scalar('val/acc5', top5_m.val, batch_idx)
10141042
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
10151043
log_name = 'Test' + log_suffix
10161044
_logger.info(

0 commit comments

Comments
 (0)