@@ -84,7 +84,7 @@ def is_equal(config1, config2):
8484
8585
8686@pytest .fixture
87- def models_dir ():
87+ def shared_working_dir ():
8888 with tempfile .TemporaryDirectory (prefix = "models" ) as temp_dir :
8989 yield temp_dir
9090
@@ -108,9 +108,11 @@ def streams_dir():
108108
109109
110110@pytest .fixture
111- def private_conf (models_dir ):
111+ def private_conf (shared_working_dir ):
112112 cf = OmegaConf .create (DUMMY_PRIVATE_CONF )
113- cf .model_path = models_dir
113+ cf .path_shared_working_dir = shared_working_dir
114+ cf .path_shared_slurm_dir = shared_working_dir
115+
114116 return cf
115117
116118
@@ -157,6 +159,12 @@ def test_contains_private(config_fresh):
157159 assert contains_keys (config_fresh , sanitized_private_conf )
158160
159161
162+ def test_is_paths_set (config_fresh ):
163+ paths = {"model_path" : "foo" , "run_path" : "bar" }
164+
165+ assert contains_keys (config_fresh , paths )
166+
167+
160168@pytest .mark .parametrize ("overwrite_dict" , DUMMY_OVERWRITES , indirect = True )
161169def test_load_with_overwrite_dict (overwrite_dict , private_config_file ):
162170 cf = config .load_config (private_config_file , None , None , overwrite_dict )
@@ -179,6 +187,15 @@ def test_load_with_overwrite_file(private_config_file, overwrite_file):
179187 assert contains (cf , sub_cf )
180188
181189
190+ def test_load_with_stream_in_overwrite (private_config_file , streams_dir , mocker ):
191+ overwrite = {"streams_directory" : streams_dir }
192+ stub = mocker .patch ("weathergen.utils.config.load_streams" , return_value = streams_dir )
193+
194+ config .load_config (private_config_file , None , None , overwrite )
195+
196+ stub .assert_called_once_with (streams_dir )
197+
198+
182199def test_load_multiple_overwrites (private_config_file ):
183200 overwrites = [{"foo" : 1 , "bar" : 1 , "baz" : 1 }, {"foo" : 2 , "bar" : 2 }, {"foo" : 3 }]
184201
0 commit comments