|
64 | 64 | has_functorch = True
|
65 | 65 | except ImportError as e:
|
66 | 66 | has_functorch = False
|
67 |
| - |
# Probe for optional TensorBoard support. torch.utils.tensorboard raises
# ImportError unless the `tensorboard` package is installed, so the flag
# lets --log-tensorboard degrade to a warning instead of crashing.
try:
    from torch.utils.tensorboard import SummaryWriter
    has_tensorboard = True
except ImportError:
    # `tensorboard` not installed; logging request is warned about later.
    has_tensorboard = False
68 | 73 | has_compile = hasattr(torch, 'compile')
|
69 | 74 |
|
70 | 75 |
|
|
347 | 352 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
348 | 353 | group.add_argument('--log-wandb', action='store_true', default=False,
|
349 | 354 | help='log training and validation metrics to wandb')
|
350 |
| - |
351 |
| - |
| 355 | +group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH', |
| 356 | + help='log training and validation metrics to TensorBoard') |
352 | 357 | def _parse_args():
|
353 | 358 | # Do we have a config file to parse?
|
354 | 359 | args_config, remaining = config_parser.parse_known_args()
|
@@ -726,6 +731,16 @@ def main():
|
726 | 731 | "You've requested to log metrics to wandb but package not found. "
|
727 | 732 | "Metrics not being logged to wandb, try `pip install wandb`")
|
728 | 733 |
|
| 734 | + if utils.is_primary(args) and args.log_tensorboard: |
| 735 | + if has_tensorboard: |
| 736 | + writer = SummaryWriter(args.log_tensorboard) |
| 737 | + else: |
| 738 | + _logger.warning( |
| 739 | + "You've requested to log metrics to tensorboard but package not found. " |
| 740 | + "Metrics not being logged to tensorboard, try `pip install tensorboard`") |
| 741 | + |
| 742 | + |
| 743 | + |
729 | 744 | # setup learning rate schedule and starting epoch
|
730 | 745 | updates_per_epoch = len(loader_train)
|
731 | 746 | lr_scheduler, num_epochs = create_scheduler_v2(
|
@@ -809,6 +824,7 @@ def main():
|
809 | 824 | lr=sum(lrs) / len(lrs),
|
810 | 825 | write_header=best_metric is None,
|
811 | 826 | log_wandb=args.log_wandb and has_wandb,
|
| 827 | + |
812 | 828 | )
|
813 | 829 |
|
814 | 830 | if saver is not None:
|
@@ -903,6 +919,10 @@ def train_one_epoch(
|
903 | 919 |
|
904 | 920 | num_updates += 1
|
905 | 921 | batch_time_m.update(time.time() - end)
|
| 922 | + #write to tensorboard if enabled |
| 923 | + if should_log_to_tensorboard(args): |
| 924 | + writer.add_scalar('train/loss', losses_m.val, num_updates) |
| 925 | + writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates) |
906 | 926 | if last_batch or batch_idx % args.log_interval == 0:
|
907 | 927 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
|
908 | 928 | lr = sum(lrl) / len(lrl)
|
@@ -954,6 +974,10 @@ def train_one_epoch(
|
954 | 974 | return OrderedDict([('loss', losses_m.avg)])
|
955 | 975 |
|
956 | 976 |
|
def should_log_to_tensorboard(args):
    """Return True iff this process should emit TensorBoard scalars.

    Requires all three of:
      * ``--log-tensorboard`` was given a (non-empty) log directory path,
      * this is the primary (rank-0) process in distributed training,
      * the tensorboard package was importable at startup.

    Args:
        args: parsed argparse namespace; only ``log_tensorboard`` is read
            here (``utils.is_primary`` inspects its own fields).

    Returns:
        bool: normalized to a real bool — the original expression leaked
        the raw ``log_tensorboard`` string when it was empty/falsy.

    NOTE(review): call sites in train_one_epoch/validate reference a
    ``writer`` created as a local of main(); confirm it is passed in or
    made module-global, otherwise logging raises NameError — TODO verify.
    """
    # Short-circuit on the cheap local check before any distributed query.
    if not args.log_tensorboard:
        return False
    return bool(utils.is_primary(args) and has_tensorboard)
957 | 981 | def validate(
|
958 | 982 | model,
|
959 | 983 | loader,
|
@@ -1011,6 +1035,10 @@ def validate(
|
1011 | 1035 |
|
1012 | 1036 | batch_time_m.update(time.time() - end)
|
1013 | 1037 | end = time.time()
|
| 1038 | + if should_log_to_tensorboard(args): |
| 1039 | + writer.add_scalar('val/loss', losses_m.val, batch_idx) |
| 1040 | + writer.add_scalar('val/acc1', top1_m.val, batch_idx) |
| 1041 | + writer.add_scalar('val/acc5', top5_m.val, batch_idx) |
1014 | 1042 | if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
|
1015 | 1043 | log_name = 'Test' + log_suffix
|
1016 | 1044 | _logger.info(
|
|
0 commit comments