Skip to content

Commit 9cace77

Browse files
authored
renamed train_batch_size into batch_size
1 parent f234750 commit 9cace77

27 files changed

+32
-32
lines changed

src/templates/template-common/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
seed: 777
22
data_path: ./
3-
train_batch_size: 32
3+
batch_size: 32
44
eval_batch_size: 32
55
num_workers: 4
66
max_epochs: 20

src/templates/template-text-classification/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def setup_data(config):
6161

6262
dataloader_train = idist.auto_dataloader(
6363
dataset_train,
64-
batch_size=config.train_batch_size,
64+
batch_size=config.batch_size,
6565
num_workers=config.num_workers,
6666
shuffle=True,
6767
drop_last=True,

src/templates/template-text-classification/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_setup_data():
2727
model="bert-base-uncased",
2828
tokenizer_dir="/tmp/tokenizer",
2929
max_length=1,
30-
train_batch_size=1,
30+
batch_size=1,
3131
eval_batch_size=1,
3232
num_workers=1,
3333
)

src/templates/template-vision-classification/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def setup_data(config: Any):
1010
1111
Parameters
1212
----------
13-
config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
13+
config: needs to contain `data_path`, `batch_size`, `eval_batch_size`, and `num_workers`
1414
"""
1515
#::: if (it.use_dist) { :::#
1616
local_rank = idist.get_local_rank()
@@ -59,7 +59,7 @@ def setup_data(config: Any):
5959

6060
dataloader_train = idist.auto_dataloader(
6161
dataset_train,
62-
batch_size=config.train_batch_size,
62+
batch_size=config.batch_size,
6363
shuffle=True,
6464
num_workers=config.num_workers,
6565
)

src/templates/template-vision-classification/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def set_up():
2525

2626
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
2727
def test_setup_data():
28-
config = Namespace(data_path="~/data", train_batch_size=1, eval_batch_size=1, num_workers=0)
28+
config = Namespace(data_path="~/data", batch_size=1, eval_batch_size=1, num_workers=0)
2929
dataloader_train, dataloader_eval = setup_data(config)
3030

3131
assert isinstance(dataloader_train, DataLoader)

src/templates/template-vision-dcgan/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def setup_data(config: Any):
1010
1111
Parameters
1212
----------
13-
config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
13+
config: needs to contain `data_path`, `batch_size`, `eval_batch_size`, and `num_workers`
1414
"""
1515
#::: if (it.use_dist) { :::#
1616
local_rank = idist.get_local_rank()
@@ -49,7 +49,7 @@ def setup_data(config: Any):
4949

5050
dataloader_train = idist.auto_dataloader(
5151
dataset_train,
52-
batch_size=config.train_batch_size,
52+
batch_size=config.batch_size,
5353
shuffle=True,
5454
num_workers=config.num_workers,
5555
)

src/templates/template-vision-dcgan/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def run(local_rank: int, config: Any):
3939
device = idist.device()
4040

4141
fixed_noise = torch.randn(
42-
config.train_batch_size // idist.get_world_size(),
42+
config.batch_size // idist.get_world_size(),
4343
config.z_dim,
4444
1,
4545
1,

src/templates/template-vision-dcgan/test_all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def set_up():
2626

2727
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
2828
def test_setup_data():
29-
config = Namespace(data_path="~/data", train_batch_size=1, eval_batch_size=1, num_workers=0)
29+
config = Namespace(data_path="~/data", batch_size=1, eval_batch_size=1, num_workers=0)
3030
dataloader_train, dataloader_eval, _ = setup_data(config)
3131

3232
assert isinstance(dataloader_train, DataLoader)
@@ -60,7 +60,7 @@ def test_models():
6060

6161
def test_setup_trainer():
6262
model, optimizer, device, loss_fn, batch = set_up()
63-
config = Namespace(use_amp=False, train_batch_size=2, z_dim=100)
63+
config = Namespace(use_amp=False, batch_size=2, z_dim=100)
6464
trainer = setup_trainer(config, model, model, optimizer, optimizer, loss_fn, device, None)
6565
trainer.run([batch, batch])
6666
assert isinstance(trainer.state.output, dict)

src/templates/template-vision-dcgan/trainers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def setup_trainer(
2121
) -> Union[Engine, DeterministicEngine]:
2222
ws = idist.get_world_size()
2323

24-
real_labels = torch.ones(config.train_batch_size // ws, device=device)
25-
fake_labels = torch.zeros(config.train_batch_size // ws, device=device)
26-
noise = torch.randn(config.train_batch_size // ws, config.z_dim, 1, 1, device=device)
24+
real_labels = torch.ones(config.batch_size // ws, device=device)
25+
fake_labels = torch.zeros(config.batch_size // ws, device=device)
26+
noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)
2727

2828
def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
2929
model_g.train()

src/templates/template-vision-segmentation/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def setup_data(config: Namespace):
124124
dataloader_train = idist.auto_dataloader(
125125
dataset_train,
126126
shuffle=True,
127-
batch_size=config.train_batch_size,
127+
batch_size=config.batch_size,
128128
num_workers=config.num_workers,
129129
drop_last=True,
130130
)

0 commit comments

Comments
 (0)