From 0aca27e3cb9af4dc36013eec9711f19818cb41da Mon Sep 17 00:00:00 2001
From: exx8
Date: Tue, 14 Mar 2023 22:46:20 +0200
Subject: [PATCH 1/7] step-based logging

---
 train.py | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index 816f4ae804..c7dfd7eda1 100755
--- a/train.py
+++ b/train.py
@@ -64,7 +64,12 @@
     has_functorch = True
 except ImportError as e:
     has_functorch = False
-
+# test tensorboard install
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    has_tensorboard = True
+except ImportError as e:
+    has_tensorboard = False
 
 has_compile = hasattr(torch, 'compile')
 
@@ -347,8 +352,8 @@
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
                    help='log training and validation metrics to wandb')
-
-
+group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
+                   help='log training and validation metrics to TensorBoard')
 def _parse_args():
     # Do we have a config file to parse?
     args_config, remaining = config_parser.parse_known_args()
@@ -726,6 +731,16 @@ def main():
                 "You've requested to log metrics to wandb but package not found. "
                 "Metrics not being logged to wandb, try `pip install wandb`")
 
+    if utils.is_primary(args) and args.log_tensorboard:
+        if has_tensorboard:
+            writer = SummaryWriter(args.log_tensorboard)
+        else:
+            _logger.warning(
+                "You've requested to log metrics to tensorboard but package not found. "
+                "Metrics not being logged to tensorboard, try `pip install tensorboard`")
+
+
+
     # setup learning rate schedule and starting epoch
     updates_per_epoch = len(loader_train)
     lr_scheduler, num_epochs = create_scheduler_v2(
@@ -809,6 +824,7 @@
                     lr=sum(lrs) / len(lrs),
                     write_header=best_metric is None,
                     log_wandb=args.log_wandb and has_wandb,
+
                 )
 
             if saver is not None:
@@ -903,6 +919,10 @@
 
         num_updates += 1
         batch_time_m.update(time.time() - end)
+        # write to tensorboard if enabled
+        if should_log_to_tensorboard(args):
+            writer.add_scalar('train/loss', losses_m.val, num_updates)
+            writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
@@ -954,6 +974,10 @@
     return OrderedDict([('loss', losses_m.avg)])
 
 
+def should_log_to_tensorboard(args):
+    return args.log_tensorboard and utils.is_primary(args) and has_tensorboard
+
+
 def validate(
         model,
         loader,
@@ -1011,6 +1035,10 @@
 
             batch_time_m.update(time.time() - end)
             end = time.time()
+            if should_log_to_tensorboard(args):
+                writer.add_scalar('val/loss', losses_m.val, batch_idx)
+                writer.add_scalar('val/acc1', top1_m.val, batch_idx)
+                writer.add_scalar('val/acc5', top5_m.val, batch_idx)
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
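The pattern patch 1 establishes: the TensorBoard dependency is optional, so the import is guarded and every logging call sits behind a capability check, mirroring the existing wandb handling. A minimal standalone sketch of that flow (the log directory and the toy loop are illustrative, not from the patch):

```python
# Minimal sketch of the guarded-import + per-update logging flow that
# patch 1 wires into train.py. Log directory and loop values are made up.
try:
    from torch.utils.tensorboard import SummaryWriter
    has_tensorboard = True
except ImportError:
    has_tensorboard = False

if has_tensorboard:
    writer = SummaryWriter('output/tb')      # mirrors --log-tensorboard PATH
    for num_updates in range(100):
        loss = 1.0 / (num_updates + 1)       # stand-in for losses_m.val
        writer.add_scalar('train/loss', loss, num_updates)
    writer.close()
```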
" "Metrics not being logged to wandb, try `pip install wandb`") - if utils.is_primary(args) and args.log_tensorboard: + if should_log_to_tensorboard(args): if has_tensorboard: writer = SummaryWriter(args.log_tensorboard) else: From 1c2d40101df9d056c5beef495128a31343d7d9cf Mon Sep 17 00:00:00 2001 From: exx8 Date: Tue, 14 Mar 2023 22:51:15 +0200 Subject: [PATCH 3/7] fix --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 84284e2890..20dd021ce1 100755 --- a/train.py +++ b/train.py @@ -730,10 +730,9 @@ def main(): _logger.warning( "You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") - if should_log_to_tensorboard(args): if has_tensorboard: - writer = SummaryWriter(args.log_tensorboard) + tensorboard_writer = SummaryWriter(args.log_tensorboard) else: _logger.warning( "You've requested to log metrics to tensorboard but package not found. " @@ -785,6 +784,7 @@ def main(): loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + tensorboard_writer=tensorboard_writer, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -857,7 +857,8 @@ def train_one_epoch( amp_autocast=suppress, loss_scaler=None, model_ema=None, - mixup_fn=None + mixup_fn=None, + tensorboard_writer=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: From 04ec4504fb276099b495c60de8c1b72ccd1fe091 Mon Sep 17 00:00:00 2001 From: exx8 Date: Tue, 14 Mar 2023 22:53:24 +0200 Subject: [PATCH 4/7] fix --- train.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 20dd021ce1..fabbabc723 100755 --- a/train.py +++ b/train.py @@ -798,6 +798,7 @@ def main(): validate_loss_fn, args, amp_autocast=amp_autocast, + tensorboard_writer=tensorboard_writer, ) if model_ema is not None and not args.model_ema_force_cpu: @@ -922,8 +923,8 @@ def train_one_epoch( batch_time_m.update(time.time() - end) #write to tensorboard if enabled if should_log_to_tensorboard(args): - writer.add_scalar('train/loss', losses_m.val, num_updates) - writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates) + tensorboard_writer.add_scalar('train/loss', losses_m.val, num_updates) + tensorboard_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates) if last_batch or batch_idx % args.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) @@ -986,7 +987,9 @@ def validate( args, device=torch.device('cuda'), amp_autocast=suppress, - log_suffix='' + log_suffix='', + tensorboard_writer=None, + ): batch_time_m = utils.AverageMeter() losses_m = utils.AverageMeter() @@ -1037,9 +1040,9 @@ def validate( batch_time_m.update(time.time() - end) end = time.time() if should_log_to_tensorboard(args): - writer.add_scalar('val/loss', losses_m.val, batch_idx) - writer.add_scalar('val/acc1', top1_m.val, batch_idx) - writer.add_scalar('val/acc5', top5_m.val, batch_idx) + tensorboard_writer.add_scalar('val/loss', losses_m.val, batch_idx) + tensorboard_writer.add_scalar('val/acc1', top1_m.val, batch_idx) + tensorboard_writer.add_scalar('val/acc5', top5_m.val, batch_idx) if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + log_suffix _logger.info( From 5e31e00809c01742b6f8035eb1bed35066b73a4c Mon Sep 17 00:00:00 2001 From: exx8 Date: Tue, 14 Mar 2023 23:16:36 +0200 Subject: [PATCH 5/7] 
From 5e31e00809c01742b6f8035eb1bed35066b73a4c Mon Sep 17 00:00:00 2001
From: exx8
Date: Tue, 14 Mar 2023 23:16:36 +0200
Subject: [PATCH 5/7] add epoch values, as TensorBoard needs them to place
 validation metrics

---
 train.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index fabbabc723..6cfbf8a61f 100755
--- a/train.py
+++ b/train.py
@@ -799,6 +799,7 @@
                 args,
                 amp_autocast=amp_autocast,
                 tensorboard_writer=tensorboard_writer,
+                epoch=epoch,
             )
 
         if model_ema is not None and not args.model_ema_force_cpu:
@@ -812,6 +813,8 @@
                     args,
                     amp_autocast=amp_autocast,
                     log_suffix=' (EMA)',
+                    tensorboard_writer=tensorboard_writer,
+                    epoch=epoch,
                 )
                 eval_metrics = ema_eval_metrics
 
@@ -989,6 +992,7 @@
         amp_autocast=suppress,
         log_suffix='',
         tensorboard_writer=None,
+        epoch=None,
 
 ):
     batch_time_m = utils.AverageMeter()
@@ -1040,9 +1044,10 @@
             batch_time_m.update(time.time() - end)
             end = time.time()
             if should_log_to_tensorboard(args):
-                tensorboard_writer.add_scalar('val/loss', losses_m.val, batch_idx)
-                tensorboard_writer.add_scalar('val/acc1', top1_m.val, batch_idx)
-                tensorboard_writer.add_scalar('val/acc5', top5_m.val, batch_idx)
+                # index by the global validation step: epoch * last_idx + batch_idx
+                tensorboard_writer.add_scalar('val/loss', losses_m.val, epoch*last_idx+batch_idx)
+                tensorboard_writer.add_scalar('val/acc1', top1_m.val, epoch*last_idx+batch_idx)
+                tensorboard_writer.add_scalar('val/acc5', top5_m.val, epoch*last_idx+batch_idx)
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(

From 4c7ba46505bbe04c088c17e5930e16c7d1bf67a4 Mon Sep 17 00:00:00 2001
From: exx8
Date: Tue, 14 Mar 2023 23:20:53 +0200
Subject: [PATCH 6/7] If we call validate() without the epoch value, we can't
 log to TensorBoard (we have no absolute reference to where we are in the
 training process).

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 6cfbf8a61f..7927fb2e6f 100755
--- a/train.py
+++ b/train.py
@@ -1043,7 +1043,7 @@
 
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if should_log_to_tensorboard(args):
+            if should_log_to_tensorboard(args) and epoch is not None:
                 # index by the global validation step: epoch * last_idx + batch_idx
                 tensorboard_writer.add_scalar('val/loss', losses_m.val, epoch*last_idx+batch_idx)
                 tensorboard_writer.add_scalar('val/acc1', top1_m.val, epoch*last_idx+batch_idx)

From 2c324050d4ba5f4ee353ef321653e4e8ee524e03 Mon Sep 17 00:00:00 2001
From: exx8
Date: Wed, 15 Mar 2023 01:21:25 +0200
Subject: [PATCH 7/7] add default value for tensorboard_writer

---
 train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/train.py b/train.py
index 7927fb2e6f..080c4f3f52 100755
--- a/train.py
+++ b/train.py
@@ -730,6 +730,7 @@ def main():
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
                 "Metrics not being logged to wandb, try `pip install wandb`")
+    tensorboard_writer = None
     if should_log_to_tensorboard(args):
         if has_tensorboard:
             tensorboard_writer = SummaryWriter(args.log_tensorboard)
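One last note on the step arithmetic from patches 5 and 6: validate() runs once per epoch, so indexing by `batch_idx` alone made each epoch overwrite the previous epoch's points, and `epoch*last_idx+batch_idx` spreads them along a single global axis (the `epoch is not None` guard protects callers that never pass an epoch). A toy illustration of the arithmetic, with made-up sizes:

```python
# Toy illustration of the global-step arithmetic from patches 5-6.
# last_idx mirrors len(loader) - 1 in validate(); the sizes are made up.
last_idx = 99                                # a loader with 100 batches
for epoch in range(2):
    for batch_idx in (0, last_idx):          # first and last batch only
        print(epoch, batch_idx, epoch * last_idx + batch_idx)
# Prints 0 0 0 / 0 99 99 / 1 0 99 / 1 99 198: note the repeat at step 99,
# since the stride is last_idx rather than len(loader); a stride of
# last_idx + 1 would make the steps strictly increasing.
```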