diff --git a/train.py b/train.py index f0f73cb..a5633b4 100644 --- a/train.py +++ b/train.py @@ -54,7 +54,7 @@ def calc_loss(loss_dict, results, logger, global_step): def train(net, data_loader, loss_dict, optimizer, scheduler,logger, epoch, metric_dict, use_aux): net.train() - progress_bar = dist_tqdm(train_loader) + progress_bar = dist_tqdm(data_loader) t_data_0 = time.time() for b_idx, data_label in enumerate(progress_bar): t_data_1 = time.time()