File tree Expand file tree Collapse file tree 10 files changed +37
-15
lines changed
template-text-classification
template-vision-classification
template-vision-segmentation Expand file tree Collapse file tree 10 files changed +37
-15
lines changed Original file line number Diff line number Diff line change @@ -2,9 +2,15 @@ def test_save_config():
2
2
with open ("./config.yaml" , "r" ) as f :
3
3
config = OmegaConf .load (f )
4
4
5
- save_config (config , "./" )
5
+ # Add backend to config (similar to setup_config)
6
+ config .backend = None
6
7
7
- with open ( "./config-lock.yaml" , "r" ) as f :
8
- test_config = OmegaConf . load ( f )
8
+ with tempfile . TemporaryDirectory ( ) as output_dir :
9
+ output_dir = Path ( output_dir )
9
10
10
- assert config == test_config
11
+ save_config (config , output_dir )
12
+
13
+ with open (output_dir / "config-lock.yaml" , "r" ) as f :
14
+ test_config = OmegaConf .load (f )
15
+
16
+ assert config == test_config
Original file line number Diff line number Diff line change @@ -149,14 +149,14 @@ def resume_from(
149
149
150
150
def setup_output_dir (config : Any , rank : int ) -> Path :
151
151
"""Create output folder."""
152
+ output_dir = config .output_dir
152
153
if rank == 0 :
153
154
now = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
154
155
name = f"{ now } -backend-{ config .backend } -lr-{ config .lr } "
155
156
path = Path (config .output_dir , name )
156
157
path .mkdir (parents = True , exist_ok = True )
157
- config .output_dir = path .as_posix ()
158
-
159
- return Path (idist .broadcast (config .output_dir , src = 0 ))
158
+ output_dir = path .as_posix ()
159
+ return Path (idist .broadcast (output_dir , src = 0 ))
160
160
161
161
162
162
def save_config (config , output_dir ):
Original file line number Diff line number Diff line change @@ -27,9 +27,11 @@ def run(local_rank: int, config: Any):
27
27
manual_seed (config .seed + rank )
28
28
29
29
# create output folder and copy config file to output dir
30
- config . output_dir = setup_output_dir (config , rank )
30
+ output_dir = setup_output_dir (config , rank )
31
31
if rank == 0 :
32
- save_config (config , config .output_dir )
32
+ save_config (config , output_dir )
33
+
34
+ config .output_dir = output_dir
33
35
34
36
# donwload datasets and create dataloaders
35
37
dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 1
1
import os
2
+ import tempfile
2
3
from argparse import Namespace
4
+ from pathlib import Path
3
5
from typing import Iterable
4
6
5
7
import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -24,9 +24,11 @@ def run(local_rank: int, config: Any):
24
24
manual_seed (config .seed + rank )
25
25
26
26
# create output folder and copy config file to output dir
27
- config . output_dir = setup_output_dir (config , rank )
27
+ output_dir = setup_output_dir (config , rank )
28
28
if rank == 0 :
29
- save_config (config , config .output_dir )
29
+ save_config (config , output_dir )
30
+
31
+ config .output_dir = output_dir
30
32
31
33
# donwload datasets and create dataloaders
32
34
dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 1
1
import os
2
+ import tempfile
2
3
from argparse import Namespace
4
+ from pathlib import Path
3
5
from typing import Iterable
4
6
5
7
import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -28,9 +28,11 @@ def run(local_rank: int, config: Any):
28
28
manual_seed (config .seed + rank )
29
29
30
30
# create output folder and copy config file to output dir
31
- config . output_dir = setup_output_dir (config , rank )
31
+ output_dir = setup_output_dir (config , rank )
32
32
if rank == 0 :
33
- save_config (config , config .output_dir )
33
+ save_config (config , output_dir )
34
+
35
+ config .output_dir = output_dir
34
36
35
37
# donwload datasets and create dataloaders
36
38
dataloader_train , dataloader_eval , num_channels = setup_data (config )
Original file line number Diff line number Diff line change 1
1
import os
2
+ import tempfile
2
3
from argparse import Namespace
4
+ from pathlib import Path
3
5
from typing import Iterable
4
6
5
7
import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -34,9 +34,11 @@ def run(local_rank: int, config: Any):
34
34
manual_seed (config .seed + rank )
35
35
36
36
# create output folder and copy config file to output dir
37
- config . output_dir = setup_output_dir (config , rank )
37
+ output_dir = setup_output_dir (config , rank )
38
38
if rank == 0 :
39
- save_config (config , config .output_dir )
39
+ save_config (config , output_dir )
40
+
41
+ config .output_dir = output_dir
40
42
41
43
# donwload datasets and create dataloaders
42
44
dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 1
1
import os
2
+ import tempfile
2
3
from argparse import Namespace
4
+ from pathlib import Path
3
5
4
6
import pytest
5
7
from data import setup_data
You can’t perform that action at this time.
0 commit comments