Skip to content

Commit 6bcf050

Browse files
huppd and AnnikaLau authored
Revise member selection (#54)
- clarify naming convention: instead of `member_num`, use either `member_id`, `member_ids`, or `member_count` - revise the member selection algorithm: instead of checking at random, it uses this deterministic approach. At each step, every ensemble member is evaluated to determine how many of the remaining members satisfy the given tolerances when compared to it. The ensemble member that results in the highest number of other members passing the tolerances is selected. This process is repeated until either all members are successfully validated using the current tolerances, or a maximum number of attempts is reached. In the latter case, the tolerance factor is increased, and the process continues. --------- Co-authored-by: Annika Lauber <annika.lauber@c2sm.ethz.ch>
1 parent 0cbb108 commit 6bcf050

33 files changed

+646
-704
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ python ../externals/probtest/probtest.py tolerance
191191
These commands will generate a number of files:
192192

193193
- `stats_ref.csv`: contains the post-processed output from the unperturbed reference run
194-
- `stats_{member_num}.csv`: contain the post-processed output from the perturbed reference runs (only needed temporarily to generate the tolerance file)
194+
- `stats_{member_id}.csv`: contain the post-processed output from the perturbed reference runs (only needed temporarily to generate the tolerance file)
195195
- `exp_name_tolerance.csv`: contains tolerance ranges computed from the stats-files
196196

197197
These can then be used to compare against the output of a test binary (usually a

engine/cdo_table.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import xarray as xr
1515

1616
from util import model_output_parser
17-
from util.click_util import CommaSeperatedInts, cli_help
17+
from util.click_util import cli_help
1818
from util.constants import cdo_bins
1919
from util.dataframe_ops import df_from_file_ids
2020
from util.file_system import file_names_from_pattern
@@ -93,10 +93,10 @@ def rel_diff_stats(
9393
help=cli_help["file_id"],
9494
)
9595
@click.option(
96-
"--member-num",
97-
type=CommaSeperatedInts(),
98-
default="10",
99-
help=cli_help["member_num"],
96+
"--member-id",
97+
type=int,
98+
default=1,
99+
help=cli_help["member_id"],
100100
)
101101
@click.option(
102102
"--member-type",
@@ -120,20 +120,17 @@ def rel_diff_stats(
120120
def cdo_table(
121121
model_output_dir,
122122
file_id,
123-
member_num,
123+
member_id: int,
124124
member_type,
125125
perturbed_model_output_dir,
126126
cdo_table_file,
127127
file_specification,
128128
): # pylint: disable=too-many-positional-arguments
129-
# TODO: A single perturbed run provides enough data to make proper statistics.
130-
# refactor cdo_table interface to reflect that
131-
if len(member_num) == 1:
132-
member_num = list(range(1, member_num[0] + 1))
129+
133130
if member_type:
134-
member_id = member_type + "_" + str(member_num[0])
131+
member_id_str = member_type + "_" + str(member_id)
135132
else:
136-
member_id = str(member_num[0])
133+
member_id_str = str(member_id)
137134

138135
file_specification = file_specification[0] # can't store dicts as defaults in click
139136
assert isinstance(file_specification, dict), "must be dict"
@@ -154,7 +151,7 @@ def cdo_table(
154151
continue
155152
ref_files.sort()
156153
perturb_files, err = file_names_from_pattern(
157-
perturbed_model_output_dir.format(member_id=member_id), file_pattern
154+
perturbed_model_output_dir.format(member_id=member_id_str), file_pattern
158155
)
159156
if err > 0:
160157
logger.info(
@@ -168,7 +165,7 @@ def cdo_table(
168165
continue
169166
ref_data = xr.open_dataset(f"{model_output_dir}/{rf}")
170167
perturb_data = xr.open_dataset(
171-
f"{perturbed_model_output_dir.format(member_id=member_id)}/{pf}"
168+
f"{perturbed_model_output_dir.format(member_id=member_id_str)}/{pf}"
172169
)
173170
diff_data = ref_data.copy()
174171
varnames = [

engine/check.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import click
1313

1414
from util.click_util import cli_help
15-
from util.dataframe_ops import compute_div_dataframe, test_stats_file_with_tolerances
15+
from util.dataframe_ops import check_stats_file_with_tolerances, compute_div_dataframe
1616
from util.log_handler import logger
1717

1818

@@ -36,7 +36,7 @@
3636
)
3737
def check(input_file_ref, input_file_cur, tolerance_file_name, factor):
3838

39-
out, err, tol = test_stats_file_with_tolerances(
39+
out, err, tol = check_stats_file_with_tolerances(
4040
tolerance_file_name, input_file_ref, input_file_cur, factor
4141
)
4242

engine/init.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
help=cli_help["config"],
4444
)
4545
@click.option(
46-
"--member-num",
46+
"--member-ids",
4747
type=CommaSeperatedInts(),
48-
default="10",
49-
help=cli_help["member_num"],
48+
default="1,2,3,4,5,6,7,8,9,10",
49+
help=cli_help["member_ids"],
5050
)
5151
@click.option(
5252
"--member-type",
@@ -89,14 +89,15 @@ def init(
8989
reference,
9090
config,
9191
template_name,
92-
member_num,
92+
member_ids,
9393
member_type,
9494
factor,
9595
perturb_amplitude,
9696
timing_current,
9797
timing_reference,
9898
append_time,
9999
): # pylint: disable=too-many-positional-arguments
100+
100101
template_partition = str(template_name).rpartition("/")
101102
env = Environment(
102103
loader=FileSystemLoader(template_partition[0]), undefined=StrictUndefined
@@ -115,8 +116,8 @@ def init(
115116
logger.warning(warn_template, "file_id", "")
116117
if not reference:
117118
logger.warning(warn_template, "reference", "")
118-
if not member_num:
119-
logger.warning(warn_template, "member_num", member_num)
119+
if not member_ids:
120+
logger.warning(warn_template, "member_ids", member_ids)
120121
if not member_type:
121122
logger.warning(warn_template, "member_type", member_type)
122123
if not factor:
@@ -135,7 +136,7 @@ def init(
135136
render_dict["experiment_name"] = experiment_name
136137
render_dict["codebase_install"] = Path(codebase_install).resolve()
137138
render_dict["reference"] = Path(reference).resolve()
138-
render_dict["member_num"] = member_num
139+
render_dict["member_ids"] = member_ids
139140
render_dict["member_type"] = member_type
140141
render_dict["factor"] = factor
141142
render_dict["perturb_amplitude"] = perturb_amplitude
@@ -148,7 +149,7 @@ def init(
148149
# append file_id via json
149150
json_dict = json.loads(rendered)
150151
json_dict["default"]["file_id"] = file_id
151-
json_dict["default"]["member_num"] = member_num
152+
json_dict["default"]["member_ids"] = member_ids
152153
json_dict["default"]["factor"] = factor
153154
rendered = json.dumps(json_dict, indent=2)
154155
# print file

engine/perturb.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
1818
from util.log_handler import logger
1919
from util.netcdf_io import nc4_get_copy
20-
from util.utils import get_seed_from_member_number, process_member_num
20+
from util.utils import get_seed_from_member_id
2121

2222

2323
def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
@@ -64,10 +64,10 @@ def perturb_array(array, s, a):
6464
help=cli_help["files"],
6565
)
6666
@click.option(
67-
"--member-num",
67+
"--member-ids",
6868
type=CommaSeperatedInts(),
69-
default="10",
70-
help=cli_help["member_num"],
69+
default="1,2,3,4,5,6,7,8,9,10",
70+
help=cli_help["member_ids"],
7171
)
7272
@click.option(
7373
"--member-type",
@@ -95,33 +95,36 @@ def perturb(
9595
model_input_dir,
9696
perturbed_model_input_dir,
9797
files,
98-
member_num,
98+
member_ids,
9999
member_type,
100100
variable_names,
101101
perturb_amplitude,
102102
copy_all_files,
103103
): # pylint: disable=unused-argument, too-many-positional-arguments
104104

105-
processed_member_num = process_member_num(member_num)
106-
107-
for m_num, m_id in processed_member_num:
105+
for m_id in member_ids:
108106

109107
if member_type:
110-
m_id = member_type + "_" + m_id
108+
m_id_str = member_type + "_" + str(m_id)
109+
else:
110+
m_id_str = str(m_id)
111+
111112
perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
112-
member_id=m_id
113+
member_id=m_id_str
113114
)
115+
114116
data = create_perturb_files(
115117
model_input_dir,
116118
files,
117119
perturbed_model_input_dir_member_id,
118120
copy_all_files,
119121
)
122+
120123
for d in data:
121124
for vn in variable_names:
122125
d.variables[vn][:] = perturb_array(
123126
d.variables[vn][:],
124-
get_seed_from_member_number(m_num),
127+
get_seed_from_member_id(m_id),
125128
perturb_amplitude,
126129
)
127130
d.close()

engine/run_ensemble.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
1919
from util.log_handler import logger
20-
from util.utils import get_seed_from_member_number, process_member_num
20+
from util.utils import get_seed_from_member_id
2121

2222

2323
def is_float(string):
@@ -114,7 +114,7 @@ def append_job(job, job_list, parallel):
114114
try:
115115
time.sleep(5)
116116
p.wait()
117-
test_job_returncode(p)
117+
check_job_returncode(p)
118118
finally:
119119
p.kill()
120120
else:
@@ -127,15 +127,15 @@ def finalize_jobs(job_list, dry, parallel):
127127
for job in job_list:
128128
job.communicate()
129129
try:
130-
test_job_returncode(job)
130+
check_job_returncode(job)
131131
except subprocess.CalledProcessError as e:
132132
logger.error(e)
133133
last_exception = e
134134
if last_exception:
135135
raise last_exception
136136

137137

138-
def test_job_returncode(job):
138+
def check_job_returncode(job):
139139
"""Test job return code."""
140140
if job.returncode != 0:
141141
raise subprocess.CalledProcessError(returncode=job.returncode, cmd=job.args)
@@ -171,10 +171,10 @@ def test_job_returncode(job):
171171
help=cli_help["submit_command"],
172172
)
173173
@click.option(
174-
"--member-num",
175-
default="10",
174+
"--member-ids",
175+
default="1,2,3,4,5,6,7,8,9,10",
176176
type=CommaSeperatedInts(),
177-
help=cli_help["member_num"],
177+
help=cli_help["member_ids"],
178178
)
179179
@click.option(
180180
"--member-type",
@@ -215,7 +215,7 @@ def run_ensemble(
215215
experiment_name,
216216
perturbed_experiment_name,
217217
submit_command,
218-
member_num,
218+
member_ids,
219219
member_type,
220220
parallel,
221221
dry,
@@ -234,31 +234,33 @@ def run_ensemble(
234234
append_job(job, job_list, parallel)
235235

236236
# run the ensemble
237-
processed_member_num = process_member_num(member_num)
237+
for m_id in member_ids:
238238

239-
for m_num, m_id in processed_member_num:
240-
241-
Path(perturbed_run_dir.format(member_id=m_id)).mkdir(
239+
Path(perturbed_run_dir.format(member_id=str(m_id))).mkdir(
242240
exist_ok=True, parents=True
243241
)
244-
os.chdir(perturbed_run_dir.format(member_id=m_id))
242+
os.chdir(perturbed_run_dir.format(member_id=str(m_id)))
243+
245244
if member_type:
246-
m_id = member_type + "_" + m_id
245+
m_id_str = member_type + "_" + str(m_id)
246+
else:
247+
m_id_str = str(m_id)
248+
247249
runscript = f"{run_dir}/{run_script_name}"
248250

249-
perturbed_run_dir_path = perturbed_run_dir.format(member_id=m_id)
250-
perturbed_run_script_path = perturbed_run_script_name.format(member_id=m_id)
251+
perturbed_run_dir_path = perturbed_run_dir.format(member_id=m_id_str)
252+
perturbed_run_script_path = perturbed_run_script_name.format(member_id=m_id_str)
251253
perturbed_runscript = f"{perturbed_run_dir_path}/{perturbed_run_script_path}"
252254

253255
prepare_perturbed_run_script(
254256
runscript,
255257
perturbed_runscript,
256258
experiment_name,
257-
perturbed_experiment_name.format(member_id=m_id),
259+
perturbed_experiment_name.format(member_id=m_id_str),
258260
lhs,
259261
rhs_new,
260262
rhs_old,
261-
get_seed_from_member_number(m_num),
263+
get_seed_from_member_id(m_id),
262264
)
263265

264266
if not dry:

0 commit comments

Comments
 (0)