Skip to content

Commit 41aa69e

Browse files
author
Xiao Wang
committed
finish pre-train section
1 parent d84ccd3 commit 41aa69e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pretrain/main_worker.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,16 @@ def main_worker(gpu, ngpus_per_node,args):
193193
log_writer=log_writer,
194194
args=args
195195
)
196-
val_stats = {**{f'valid_{k}': v for k, v in val_stats.items()},
197-
'epoch': epoch,}
196+
val_loss = val_stats['loss']
197+
log_stats_val = {**{f'val_{k}': v for k, v in val_stats.items()},
198+
'epoch': epoch,}
198199
if is_main_process():
199-
write_log(log_dir,"valid",val_stats)
200+
write_log(log_dir,"val",log_stats_val)
200201
if epoch%save_freq==0 or epoch==epochs-1:
201202
#output_dir, args,epoch, model_without_ddp, optimizer, loss_scaler
202203
save_checkpoint(model_dir, args,epoch, model_without_ddp, optimizer, loss_scaler)
203204

204-
val_loss = val_stats['loss']
205+
205206
if val_loss < best_loss:
206207
best_loss = val_loss
207208
#model_path,args,epoch, model_without_ddp, optimizer, loss_scaler

0 commit comments

Comments
 (0)