Skip to content

Commit 61c0aa8

Browse files
committed
Add eval_batch_size (fix #17)
1 parent 5bcb32f commit 61c0aa8

File tree

6 files changed

+7
-4
lines changed

6 files changed

+7
-4
lines changed

audio/configs/wav2vec2-pretraining.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ dataset:
1515
optimizer:
1616
lr: 0.0005
1717
train:
18-
batch_size: 1
18+
batch_size: 16
19+
eval_batch_size: 16
1920
num_epochs: 1000
2021
log_dir: 'audio/logs'
2122
save_ckpt_freq: 10

audio/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, cfg: DictConfig):
3636
padding='longest')
3737
self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
3838
collate_fn=self.data_collator)
39-
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.val_batch_size,
39+
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
4040
collate_fn=self.data_collator)
4141
# Tensorboard
4242
self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)

text/configs/roberta-pretraining.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ modality: 'text'
22
device: 'cuda'
33
train:
44
batch_size: 32
5+
eval_batch_size: 32
56
num_epochs: 20
67
checkpoints_dir: 'text/checkpoints/roberta-pretrain'
78
log_dir: 'text/logs/roberta-pretrain'

text/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, cfg: DictConfig):
4444
self.test_dataset = WikiText(cfg, 'test', self.tokenizer)
4545
self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
4646
collate_fn=self.train_dataset.collate_fn)
47-
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.val_batch_size,
47+
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
4848
collate_fn=self.test_dataset.collate_fn)
4949
# Tensorboard
5050
self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)

vision/configs/beit-pretraining.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dataset:
2525
train:
2626
num_epochs: 800
2727
batch_size: 16
28+
eval_batch_size: 16
2829
shuffle: true
2930
save_ckpt_freq: 20
3031
checkpoints_dir: 'vision/checkpoints/beit-pretrain'

vision/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, cfg):
3131
self.train_dataset = MIMPretrainingDataset(cfg, split='train')
3232
self.test_dataset = MIMPretrainingDataset(cfg, split='test')
3333
self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, shuffle=cfg.train.shuffle)
34-
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.batch_size, shuffle=cfg.train.shuffle)
34+
self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, shuffle=cfg.train.shuffle)
3535

3636
# Tensorboard
3737
self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)

0 commit comments

Comments
 (0)