Skip to content

Commit df76bf1

Browse files
authored
Fix the output_dir logic and bugs for reproducibility (#307)
* Changed the function setup_output_dir and output_dir logic in templates * add coment again to setup_output_dir * Modify test_save_config for the new behaviour * Add temp directory for test_save_config
1 parent f41ea18 commit df76bf1

File tree

10 files changed

+37
-15
lines changed

10 files changed

+37
-15
lines changed

src/templates/template-common/test_all.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@ def test_save_config():
22
with open("./config.yaml", "r") as f:
33
config = OmegaConf.load(f)
44

5-
save_config(config, "./")
5+
# Add backend to config (similar to setup_config)
6+
config.backend = None
67

7-
with open("./config-lock.yaml", "r") as f:
8-
test_config = OmegaConf.load(f)
8+
with tempfile.TemporaryDirectory() as output_dir:
9+
output_dir = Path(output_dir)
910

10-
assert config == test_config
11+
save_config(config, output_dir)
12+
13+
with open(output_dir / "config-lock.yaml", "r") as f:
14+
test_config = OmegaConf.load(f)
15+
16+
assert config == test_config

src/templates/template-common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ def resume_from(
149149

150150
def setup_output_dir(config: Any, rank: int) -> Path:
151151
"""Create output folder."""
152+
output_dir = config.output_dir
152153
if rank == 0:
153154
now = datetime.now().strftime("%Y%m%d-%H%M%S")
154155
name = f"{now}-backend-{config.backend}-lr-{config.lr}"
155156
path = Path(config.output_dir, name)
156157
path.mkdir(parents=True, exist_ok=True)
157-
config.output_dir = path.as_posix()
158-
159-
return Path(idist.broadcast(config.output_dir, src=0))
158+
output_dir = path.as_posix()
159+
return Path(idist.broadcast(output_dir, src=0))
160160

161161

162162
def save_config(config, output_dir):

src/templates/template-text-classification/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ def run(local_rank: int, config: Any):
2727
manual_seed(config.seed + rank)
2828

2929
# create output folder and copy config file to output dir
30-
config.output_dir = setup_output_dir(config, rank)
30+
output_dir = setup_output_dir(config, rank)
3131
if rank == 0:
32-
save_config(config, config.output_dir)
32+
save_config(config, output_dir)
33+
34+
config.output_dir = output_dir
3335

3436
# donwload datasets and create dataloaders
3537
dataloader_train, dataloader_eval = setup_data(config)

src/templates/template-text-classification/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import tempfile
23
from argparse import Namespace
4+
from pathlib import Path
35
from typing import Iterable
46

57
import ignite.distributed as idist

src/templates/template-vision-classification/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def run(local_rank: int, config: Any):
2424
manual_seed(config.seed + rank)
2525

2626
# create output folder and copy config file to output dir
27-
config.output_dir = setup_output_dir(config, rank)
27+
output_dir = setup_output_dir(config, rank)
2828
if rank == 0:
29-
save_config(config, config.output_dir)
29+
save_config(config, output_dir)
30+
31+
config.output_dir = output_dir
3032

3133
# donwload datasets and create dataloaders
3234
dataloader_train, dataloader_eval = setup_data(config)

src/templates/template-vision-classification/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import tempfile
23
from argparse import Namespace
4+
from pathlib import Path
35
from typing import Iterable
46

57
import ignite.distributed as idist

src/templates/template-vision-dcgan/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ def run(local_rank: int, config: Any):
2828
manual_seed(config.seed + rank)
2929

3030
# create output folder and copy config file to output dir
31-
config.output_dir = setup_output_dir(config, rank)
31+
output_dir = setup_output_dir(config, rank)
3232
if rank == 0:
33-
save_config(config, config.output_dir)
33+
save_config(config, output_dir)
34+
35+
config.output_dir = output_dir
3436

3537
# donwload datasets and create dataloaders
3638
dataloader_train, dataloader_eval, num_channels = setup_data(config)

src/templates/template-vision-dcgan/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import tempfile
23
from argparse import Namespace
4+
from pathlib import Path
35
from typing import Iterable
46

57
import ignite.distributed as idist

src/templates/template-vision-segmentation/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def run(local_rank: int, config: Any):
3434
manual_seed(config.seed + rank)
3535

3636
# create output folder and copy config file to output dir
37-
config.output_dir = setup_output_dir(config, rank)
37+
output_dir = setup_output_dir(config, rank)
3838
if rank == 0:
39-
save_config(config, config.output_dir)
39+
save_config(config, output_dir)
40+
41+
config.output_dir = output_dir
4042

4143
# donwload datasets and create dataloaders
4244
dataloader_train, dataloader_eval = setup_data(config)

src/templates/template-vision-segmentation/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2+
import tempfile
23
from argparse import Namespace
4+
from pathlib import Path
35

46
import pytest
57
from data import setup_data

0 commit comments

Comments
 (0)