Skip to content

Commit 50e35d1

Browse files
committed
Save checkpoint at each epoch
1 parent 9c1523b commit 50e35d1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

main_miniimagenet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def generate_samples(images, model, args):
6565

6666
def main(args):
6767
writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
68-
save_filename = './models/{0}/model.pt'.format(args.output_folder)
68+
save_filename = './models/{0}'.format(args.output_folder)
6969

7070
transform = transforms.Compose([
7171
transforms.RandomResizedCrop(128),
@@ -114,8 +114,10 @@ def main(args):
114114

115115
if (epoch == 0) or (loss < best_loss):
116116
best_loss = loss
117-
with open(save_filename, 'wb') as f:
117+
with open('{0}/best.pt'.format(save_filename), 'wb') as f:
118118
torch.save(model.state_dict(), f)
119+
with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
120+
torch.save(model.state_dict(), f)
119121

120122
if __name__ == '__main__':
121123
import argparse

0 commit comments

Comments
 (0)