1
- # Copyright (c) 2022 Mira Geoscience Ltd.
1
+ # Copyright (c) 2023 Mira Geoscience Ltd.
2
2
#
3
3
# This file is part of param-sweeps.
4
4
#
8
8
from __future__ import annotations
9
9
10
10
import argparse
11
+ import importlib
12
+ import inspect
11
13
import itertools
12
14
import json
13
15
import os
14
- import subprocess
15
16
import uuid
16
17
from dataclasses import dataclass
17
18
from inspect import signature
@@ -92,6 +93,10 @@ class SweepDriver:
92
93
93
94
def __init__ (self , params ):
94
95
self .params : SweepParams = params
96
+ self .workspace = params .geoh5
97
+ self .working_directory = os .path .dirname (self .workspace .h5file )
98
+ lookup = self .get_lookup ()
99
+ self .write_files (lookup )
95
100
96
101
@staticmethod
97
102
def uuid_from_params (params : tuple ) -> str :
@@ -104,75 +109,97 @@ def uuid_from_params(params: tuple) -> str:
104
109
"""
105
110
return str (uuid .uuid5 (uuid .NAMESPACE_DNS , str (hash (params ))))
106
111
107
- def run (self , files_only = False ):
108
- """Execute a sweep."""
112
+ def get_lookup (self ):
113
+ """Generate lookup table for sweep trials."""
114
+
115
+ lookup = {}
116
+ sets = self .params .parameter_sets ()
117
+ iterations = list (itertools .product (* sets .values ()))
118
+ for iteration in iterations :
119
+ param_uuid = SweepDriver .uuid_from_params (iteration )
120
+ lookup [param_uuid ] = dict (zip (sets .keys (), iteration ))
121
+ lookup [param_uuid ]["status" ] = "pending"
122
+
123
+ lookup = self .update_lookup (lookup , gather_first = True )
124
+ return lookup
125
+
126
+ def update_lookup (self , lookup : dict , gather_first : bool = False ):
127
+ """Updates lookup with new entries. Ensures any previous runs are incorporated."""
128
+ lookup_path = os .path .join (self .working_directory , "lookup.json" )
129
+ if os .path .exists (lookup_path ) and gather_first : # In case restarting
130
+ with open (lookup_path , encoding = "utf8" ) as file :
131
+ lookup .update (json .load (file ))
132
+
133
+ with open (lookup_path , "w" , encoding = "utf8" ) as file :
134
+ json .dump (lookup , file , indent = 4 )
135
+
136
+ return lookup
137
+
138
+ def write_files (self , lookup ):
139
+ """Write ui.geoh5 and ui.json files for sweep trials."""
109
140
110
141
ifile = InputFile .read_ui_json (self .params .worker_uijson )
111
142
with ifile .data ["geoh5" ].open (mode = "r" ) as workspace :
112
- sets = self .params .parameter_sets ()
113
- iterations = list (itertools .product (* sets .values ()))
114
- print (
115
- f"Running parameter sweep for { len (iterations )} "
116
- f"trials of the { ifile .data ['title' ]} driver."
117
- )
118
143
119
- param_lookup = {}
120
- for count , iteration in enumerate (iterations ):
121
- param_uuid = SweepDriver .uuid_from_params (iteration )
122
- filepath = os .path .join (
123
- os .path .dirname (workspace .h5file ), f"{ param_uuid } .ui.geoh5"
124
- )
125
- param_lookup [param_uuid ] = dict (zip (sets .keys (), iteration ))
144
+ for name , trial in lookup .items ():
126
145
127
- if os .path .exists (filepath ):
128
- print (
129
- f"{ count } : Skipping trial: { param_uuid } . "
130
- f"Already computed and saved to file."
131
- )
146
+ if trial ["status" ] != "pending" :
132
147
continue
133
148
134
- print (
135
- f"{ count } : Running trial: { param_uuid } . "
136
- f"Use lookup.json to map uuid to parameter set."
149
+ filepath = os .path .join (
150
+ os .path .dirname (workspace .h5file ), f"{ name } .ui.geoh5"
137
151
)
138
152
with Workspace (filepath ) as iter_workspace :
139
153
ifile .data .update (
140
- dict (param_lookup [param_uuid ], ** {"geoh5" : iter_workspace })
154
+ dict (
155
+ {key : val for key , val in trial .items () if key != "status" },
156
+ ** {"geoh5" : iter_workspace },
157
+ )
141
158
)
142
159
objects = [v for v in ifile .data .values () if hasattr (v , "uid" )]
143
160
for obj in objects :
144
161
if not isinstance (obj , Data ):
145
162
obj .copy (parent = iter_workspace , copy_children = True )
146
163
147
- update_lookup (param_lookup , workspace )
148
-
149
- ifile .name = f"{ param_uuid } .ui.json"
164
+ ifile .name = f"{ name } .ui.json"
150
165
ifile .path = os .path .dirname (workspace .h5file )
151
166
ifile .write_ui_json ()
167
+ lookup [name ]["status" ] = "written"
152
168
153
- if not files_only :
154
- call_worker_subprocess (ifile )
169
+ _ = self .update_lookup (lookup )
155
170
171
+ def run (self ):
172
+ """Execute a sweep."""
156
173
157
- def call_worker_subprocess (ifile : InputFile ):
158
- """Runs the worker for the sweep parameters contained in 'ifile'."""
159
- subprocess .run (
160
- ["python" , "-m" , ifile .data ["run_command" ], ifile .path_name ],
161
- check = True ,
162
- )
174
+ lookup_path = os .path .join (self .working_directory , "lookup.json" )
175
+ with open (lookup_path , encoding = "utf8" ) as file :
176
+ lookup = json .load (file )
163
177
178
+ for name , trial in lookup .items ():
179
+ ifile = InputFile .read_ui_json (
180
+ os .path .join (self .working_directory , f"{ name } .ui.json" )
181
+ )
182
+ status = trial .pop ("status" )
183
+ if status != "complete" :
184
+ lookup [name ]["status" ] = "processing"
185
+ self .update_lookup (lookup )
186
+ call_worker (ifile )
187
+ lookup [name ]["status" ] = "complete"
188
+ self .update_lookup (lookup )
164
189
165
- def update_lookup (lookup : dict , workspace : Workspace ):
166
- """Updates lookup with new entries. Ensures any previous runs are incorporated."""
167
- lookup_path = os .path .join (os .path .dirname (workspace .h5file ), "lookup.json" )
168
- if os .path .exists (lookup_path ): # In case restarting
169
- with open (lookup_path , encoding = "utf8" ) as file :
170
- lookup .update (json .load (file ))
171
190
172
- with open ( lookup_path , "w" , encoding = "utf8" ) as file :
173
- json . dump ( lookup , file , indent = 4 )
191
+ def call_worker ( ifile : InputFile ) :
192
+ """Runs the worker for the sweep parameters contained in 'ifile'."""
174
193
175
- return lookup
194
+ run_cmd = ifile .data ["run_command" ]
195
+ module = importlib .import_module (run_cmd )
196
+ filt = (
197
+ lambda member : inspect .isclass (member )
198
+ and member .__module__ == run_cmd
199
+ and hasattr (member , "run" )
200
+ )
201
+ driver = inspect .getmembers (module , filt )[0 ][1 ]
202
+ driver .start (ifile .path_name )
176
203
177
204
178
205
def file_validation (filepath ):
@@ -188,14 +215,14 @@ def file_validation(filepath):
188
215
raise OSError (f"File argument { filepath } must have extension 'ui.json'." )
189
216
190
217
191
- def main (file_path , files_only = False ):
218
+ def main (file_path ):
192
219
"""Run the program."""
193
220
194
221
file_validation (file_path )
195
222
print ("Reading parameters and workspace..." )
196
223
input_file = InputFile .read_ui_json (file_path )
197
224
sweep_params = SweepParams .from_input_file (input_file )
198
- SweepDriver (sweep_params ).run (files_only )
225
+ SweepDriver (sweep_params ).run ()
199
226
200
227
201
228
if __name__ == "__main__" :
@@ -206,4 +233,4 @@ def main(file_path, files_only=False):
206
233
parser .add_argument ("file" , help = "File with ui.json format." )
207
234
208
235
args = parser .parse_args ()
209
- main (args .file )
236
+ main (os . path . abspath ( args .file ) )
0 commit comments