Skip to content

Commit 85a3937

Browse files
committed
train: Change argument name and value
1 parent 1b6abaa commit 85a3937

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from test import evaluate
1515

1616
parser = argparse.ArgumentParser()
17-
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
18-
parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
17+
parser.add_argument("--epoch", type=int, default=100, help="number of epoch")
18+
parser.add_argument("--gradient_accumulation", type=int, default=1, help="number of gradient accums before step")
1919
parser.add_argument("--multiscale_training", type=bool, default=True, help="allow for multi-scale training")
2020
parser.add_argument("--batch_size", type=int, default=32, help="size of each image batch")
2121
parser.add_argument("--num_workers", type=int, default=8, help="number of cpu threads to use during batch generation")
@@ -68,7 +68,7 @@
6868
loss_log = tqdm.tqdm(total=0, position=2, bar_format='{desc}', leave=False)
6969

7070
# Train code.
71-
for epoch in tqdm.tqdm(range(args.epochs), desc='Epoch'):
71+
for epoch in tqdm.tqdm(range(args.epoch), desc='Epoch'):
7272
# 모델을 train mode로 설정
7373
model.train()
7474

@@ -85,7 +85,7 @@
8585
loss.backward()
8686

8787
# 기울기 누적 (Accumulate gradient)
88-
if step % args.gradient_accumulations == 0:
88+
if step % args.gradient_accumulation == 0:
8989
optimizer.step()
9090
optimizer.zero_grad()
9191

0 commit comments

Comments
 (0)