Skip to content

Commit 6c2a47d

Browse files
authored
Merge pull request #15 from MiraGeoscience/release/0.1.4
Release/0.1.4
2 parents 76833fd + 6313796 commit 6c2a47d

File tree

12 files changed

+456
-393
lines changed

12 files changed

+456
-393
lines changed

.pre-commit-config.yaml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ ci:
1010

1111
repos:
1212
- repo: https://github.com/psf/black
13-
rev: 22.8.0
13+
rev: 22.12.0
1414
hooks:
1515
- id: black
1616
types: [text]
@@ -22,7 +22,7 @@ repos:
2222
types: [text]
2323
types_or: [python, pyi]
2424
- repo: https://github.com/PyCQA/isort
25-
rev: 5.10.1
25+
rev: 5.11.4
2626
hooks:
2727
- id: isort
2828
additional_dependencies: [toml] # to read config from pyproject.toml
@@ -35,20 +35,20 @@ repos:
3535
types: [text]
3636
types_or: [python, pyi]
3737
- repo: https://github.com/PyCQA/flake8
38-
rev: 5.0.4
38+
rev: 6.0.0
3939
hooks:
4040
- id: flake8
4141
types: [text]
4242
types_or: [python, pyi]
4343
- repo: https://github.com/asottile/pyupgrade
44-
rev: v2.37.3
44+
rev: v3.3.1
4545
hooks:
4646
- id: pyupgrade
4747
args: [--py37-plus]
4848
types: [text]
4949
types_or: [python, pyi]
5050
- repo: https://github.com/pre-commit/mirrors-mypy
51-
rev: v0.971
51+
rev: v0.991
5252
hooks:
5353
- id: mypy
5454
additional_dependencies: [types-toml]
@@ -72,14 +72,14 @@ repos:
7272
types_or: [python, pyi]
7373
exclude: ^(devtools/|docs/|setup.py)
7474
- repo: https://github.com/codespell-project/codespell
75-
rev: v2.2.1
75+
rev: v2.2.2
7676
hooks:
7777
- id: codespell
7878
exclude: (\.ipynb$|^\.github/workflows/issue_to_jira.yml$)
7979
types: [text]
8080
types_or: [python, pyi]
8181
- repo: https://github.com/pre-commit/pre-commit-hooks
82-
rev: v4.3.0
82+
rev: v4.4.0
8383
hooks:
8484
- id: trailing-whitespace
8585
exclude: \.mdj$
@@ -95,12 +95,12 @@ repos:
9595
- id: mixed-line-ending
9696
- id: name-tests-test
9797
- repo: https://github.com/rstcheck/rstcheck
98-
rev: v6.1.0
98+
rev: v6.1.1
9999
hooks:
100100
- id: rstcheck
101101
additional_dependencies: [sphinx]
102102
- repo: https://github.com/pre-commit/pygrep-hooks
103-
rev: v1.9.0
103+
rev: v1.10.0
104104
hooks:
105105
- id: rst-backticks
106106
exclude: ^THIRD_PARTY_SOFTWARE.rst$

devtools/check-copyright.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
# Copyright (c) 2022 Mira Geoscience Ltd.
3+
# Copyright (c) 2023 Mira Geoscience Ltd.
44
#
55
# This file is part of param-sweeps.
66
#

param_sweeps/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
# Copyright (c) 2022 Mira Geoscience Ltd.
1+
# Copyright (c) 2023 Mira Geoscience Ltd.
22
#
33
# This file is part of param-sweeps.
44
#
55
# param-sweeps is distributed under the terms and conditions of the MIT License
66
# (see LICENSE file at the root of this source code package).
77

8-
__version__ = "0.1.3"
8+
9+
__version__ = "0.1.4"

param_sweeps/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 Mira Geoscience Ltd.
1+
# Copyright (c) 2023 Mira Geoscience Ltd.
22
#
33
# This file is part of param-sweeps.
44
#

param_sweeps/driver.py

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 Mira Geoscience Ltd.
1+
# Copyright (c) 2023 Mira Geoscience Ltd.
22
#
33
# This file is part of param-sweeps.
44
#
@@ -8,10 +8,11 @@
88
from __future__ import annotations
99

1010
import argparse
11+
import importlib
12+
import inspect
1113
import itertools
1214
import json
1315
import os
14-
import subprocess
1516
import uuid
1617
from dataclasses import dataclass
1718
from inspect import signature
@@ -92,6 +93,10 @@ class SweepDriver:
9293

9394
def __init__(self, params):
9495
self.params: SweepParams = params
96+
self.workspace = params.geoh5
97+
self.working_directory = os.path.dirname(self.workspace.h5file)
98+
lookup = self.get_lookup()
99+
self.write_files(lookup)
95100

96101
@staticmethod
97102
def uuid_from_params(params: tuple) -> str:
@@ -104,75 +109,97 @@ def uuid_from_params(params: tuple) -> str:
104109
"""
105110
return str(uuid.uuid5(uuid.NAMESPACE_DNS, str(hash(params))))
106111

107-
def run(self, files_only=False):
108-
"""Execute a sweep."""
112+
def get_lookup(self):
113+
"""Generate lookup table for sweep trials."""
114+
115+
lookup = {}
116+
sets = self.params.parameter_sets()
117+
iterations = list(itertools.product(*sets.values()))
118+
for iteration in iterations:
119+
param_uuid = SweepDriver.uuid_from_params(iteration)
120+
lookup[param_uuid] = dict(zip(sets.keys(), iteration))
121+
lookup[param_uuid]["status"] = "pending"
122+
123+
lookup = self.update_lookup(lookup, gather_first=True)
124+
return lookup
125+
126+
def update_lookup(self, lookup: dict, gather_first: bool = False):
127+
"""Updates lookup with new entries. Ensures any previous runs are incorporated."""
128+
lookup_path = os.path.join(self.working_directory, "lookup.json")
129+
if os.path.exists(lookup_path) and gather_first: # In case restarting
130+
with open(lookup_path, encoding="utf8") as file:
131+
lookup.update(json.load(file))
132+
133+
with open(lookup_path, "w", encoding="utf8") as file:
134+
json.dump(lookup, file, indent=4)
135+
136+
return lookup
137+
138+
def write_files(self, lookup):
139+
"""Write ui.geoh5 and ui.json files for sweep trials."""
109140

110141
ifile = InputFile.read_ui_json(self.params.worker_uijson)
111142
with ifile.data["geoh5"].open(mode="r") as workspace:
112-
sets = self.params.parameter_sets()
113-
iterations = list(itertools.product(*sets.values()))
114-
print(
115-
f"Running parameter sweep for {len(iterations)} "
116-
f"trials of the {ifile.data['title']} driver."
117-
)
118143

119-
param_lookup = {}
120-
for count, iteration in enumerate(iterations):
121-
param_uuid = SweepDriver.uuid_from_params(iteration)
122-
filepath = os.path.join(
123-
os.path.dirname(workspace.h5file), f"{param_uuid}.ui.geoh5"
124-
)
125-
param_lookup[param_uuid] = dict(zip(sets.keys(), iteration))
144+
for name, trial in lookup.items():
126145

127-
if os.path.exists(filepath):
128-
print(
129-
f"{count}: Skipping trial: {param_uuid}. "
130-
f"Already computed and saved to file."
131-
)
146+
if trial["status"] != "pending":
132147
continue
133148

134-
print(
135-
f"{count}: Running trial: {param_uuid}. "
136-
f"Use lookup.json to map uuid to parameter set."
149+
filepath = os.path.join(
150+
os.path.dirname(workspace.h5file), f"{name}.ui.geoh5"
137151
)
138152
with Workspace(filepath) as iter_workspace:
139153
ifile.data.update(
140-
dict(param_lookup[param_uuid], **{"geoh5": iter_workspace})
154+
dict(
155+
{key: val for key, val in trial.items() if key != "status"},
156+
**{"geoh5": iter_workspace},
157+
)
141158
)
142159
objects = [v for v in ifile.data.values() if hasattr(v, "uid")]
143160
for obj in objects:
144161
if not isinstance(obj, Data):
145162
obj.copy(parent=iter_workspace, copy_children=True)
146163

147-
update_lookup(param_lookup, workspace)
148-
149-
ifile.name = f"{param_uuid}.ui.json"
164+
ifile.name = f"{name}.ui.json"
150165
ifile.path = os.path.dirname(workspace.h5file)
151166
ifile.write_ui_json()
167+
lookup[name]["status"] = "written"
152168

153-
if not files_only:
154-
call_worker_subprocess(ifile)
169+
_ = self.update_lookup(lookup)
155170

171+
def run(self):
172+
"""Execute a sweep."""
156173

157-
def call_worker_subprocess(ifile: InputFile):
158-
"""Runs the worker for the sweep parameters contained in 'ifile'."""
159-
subprocess.run(
160-
["python", "-m", ifile.data["run_command"], ifile.path_name],
161-
check=True,
162-
)
174+
lookup_path = os.path.join(self.working_directory, "lookup.json")
175+
with open(lookup_path, encoding="utf8") as file:
176+
lookup = json.load(file)
163177

178+
for name, trial in lookup.items():
179+
ifile = InputFile.read_ui_json(
180+
os.path.join(self.working_directory, f"{name}.ui.json")
181+
)
182+
status = trial.pop("status")
183+
if status != "complete":
184+
lookup[name]["status"] = "processing"
185+
self.update_lookup(lookup)
186+
call_worker(ifile)
187+
lookup[name]["status"] = "complete"
188+
self.update_lookup(lookup)
164189

165-
def update_lookup(lookup: dict, workspace: Workspace):
166-
"""Updates lookup with new entries. Ensures any previous runs are incorporated."""
167-
lookup_path = os.path.join(os.path.dirname(workspace.h5file), "lookup.json")
168-
if os.path.exists(lookup_path): # In case restarting
169-
with open(lookup_path, encoding="utf8") as file:
170-
lookup.update(json.load(file))
171190

172-
with open(lookup_path, "w", encoding="utf8") as file:
173-
json.dump(lookup, file, indent=4)
191+
def call_worker(ifile: InputFile):
192+
"""Runs the worker for the sweep parameters contained in 'ifile'."""
174193

175-
return lookup
194+
run_cmd = ifile.data["run_command"]
195+
module = importlib.import_module(run_cmd)
196+
filt = (
197+
lambda member: inspect.isclass(member)
198+
and member.__module__ == run_cmd
199+
and hasattr(member, "run")
200+
)
201+
driver = inspect.getmembers(module, filt)[0][1]
202+
driver.start(ifile.path_name)
176203

177204

178205
def file_validation(filepath):
@@ -188,14 +215,14 @@ def file_validation(filepath):
188215
raise OSError(f"File argument {filepath} must have extension 'ui.json'.")
189216

190217

191-
def main(file_path, files_only=False):
218+
def main(file_path):
192219
"""Run the program."""
193220

194221
file_validation(file_path)
195222
print("Reading parameters and workspace...")
196223
input_file = InputFile.read_ui_json(file_path)
197224
sweep_params = SweepParams.from_input_file(input_file)
198-
SweepDriver(sweep_params).run(files_only)
225+
SweepDriver(sweep_params).run()
199226

200227

201228
if __name__ == "__main__":
@@ -206,4 +233,4 @@ def main(file_path, files_only=False):
206233
parser.add_argument("file", help="File with ui.json format.")
207234

208235
args = parser.parse_args()
209-
main(args.file)
236+
main(os.path.abspath(args.file))

param_sweeps/generate.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 Mira Geoscience Ltd.
1+
# Copyright (c) 2023 Mira Geoscience Ltd.
22
#
33
# This file is part of param-sweeps.
44
#
@@ -9,14 +9,19 @@
99

1010
import argparse
1111
import os
12+
import re
1213
from copy import deepcopy
1314

1415
from geoh5py.ui_json import InputFile
1516

1617
from param_sweeps.constants import default_ui_json
1718

1819

19-
def generate(worker: str, parameters: list[str] = None, update_values: dict = None):
20+
def generate(
21+
worker: str,
22+
parameters: list[str] | None = None,
23+
update_values: dict | None = None,
24+
):
2025
"""
2126
Generate an *_sweep.ui.json file to sweep parameters of the driver associated with 'file'.
2227
@@ -47,7 +52,8 @@ def generate(worker: str, parameters: list[str] = None, update_values: dict = No
4752
dirname = os.path.dirname(file)
4853
filename = os.path.basename(file)
4954
filename = filename.rstrip("ui.json")
50-
filename = filename.rstrip("_sweep")
55+
filename = re.sub(r"\._sweep$", "", filename)
56+
# filename = filename.rstrip("_sweep")
5157
filename = f"{filename}_sweep.ui.json"
5258

5359
print(f"Writing sweep file to: {os.path.join(dirname, filename)}")

param_sweeps/sample_driver.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) 2023 Mira Geoscience Ltd.
2+
#
3+
# This file is part of param-sweeps.
4+
#
5+
# param-sweeps is distributed under the terms and conditions of the MIT License
6+
# (see LICENSE file at the root of this source code package).
7+
8+
from __future__ import annotations
9+
10+
from dataclasses import dataclass
11+
12+
from geoh5py.ui_json import InputFile
13+
14+
15+
@dataclass
16+
class SampleParams:
17+
18+
data_object: str | None = None
19+
data: str | None = None
20+
param: int = 1
21+
22+
def __init__(self, input_file):
23+
for key, value in input_file.data.items():
24+
setattr(self, key, value)
25+
26+
27+
class SampleDriver:
28+
def __init__(self, params):
29+
self.params = params
30+
31+
def run(self):
32+
print(self.params.param)
33+
34+
@classmethod
35+
def start(cls, filepath, driver_class=None):
36+
_ = driver_class
37+
ifile = InputFile.read_ui_json(filepath)
38+
params = SampleParams(ifile)
39+
SampleDriver(params).run()

0 commit comments

Comments
 (0)