Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ To work with probtest, a configuration file is needed. This file contains all th

### perturb

Perturbs netcdf files that can be used as input by the model.
Creates an ensemble of netcdf files that can be used as input by the model.

### run-ensemble

Expand Down
12 changes: 5 additions & 7 deletions engine/cdo_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from util.dataframe_ops import df_from_file_ids
from util.file_system import file_names_from_pattern
from util.log_handler import logger
from util.utils import prepend_type_to_member_id


def compute_rel_diff(var1, var2):
Expand Down Expand Up @@ -126,11 +127,7 @@ def cdo_table(
cdo_table_file,
file_specification,
): # pylint: disable=too-many-positional-arguments

if member_type:
member_id_str = member_type + "_" + str(member_id)
else:
member_id_str = str(member_id)


file_specification = file_specification[0] # can't store dicts as defaults in click
assert isinstance(file_specification, dict), "must be dict"
Expand All @@ -142,6 +139,7 @@ def cdo_table(

# step 1: compute rel-diff netcdf files
with tempfile.TemporaryDirectory() as tmpdir:
complete_member_id = prepend_type_to_member_id(member_type, member_id)
for _, file_pattern in file_id:
ref_files, err = file_names_from_pattern(model_output_dir, file_pattern)
if err > 0:
Expand All @@ -151,7 +149,7 @@ def cdo_table(
continue
ref_files.sort()
perturb_files, err = file_names_from_pattern(
perturbed_model_output_dir.format(member_id=member_id_str), file_pattern
perturbed_model_output_dir.format(member_id=complete_member_id), file_pattern
)
if err > 0:
logger.info(
Expand All @@ -165,7 +163,7 @@ def cdo_table(
continue
ref_data = xr.open_dataset(f"{model_output_dir}/{rf}")
perturb_data = xr.open_dataset(
f"{perturbed_model_output_dir.format(member_id=member_id_str)}/{pf}"
f"{perturbed_model_output_dir.format(member_id=complete_member_id)}/{pf}"
)
diff_data = ref_data.copy()
varnames = [
Expand Down
8 changes: 4 additions & 4 deletions engine/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import click

from util.click_util import cli_help
from util.dataframe_ops import check_stats_file_with_tolerances, compute_div_dataframe
from util.dataframe_ops import check_stats_file_with_tolerances, compute_division_df
from util.log_handler import logger


Expand Down Expand Up @@ -40,8 +40,6 @@ def check(input_file_ref, input_file_cur, tolerance_file_name, factor):
tolerance_file_name, input_file_ref, input_file_cur, factor
)

div = compute_div_dataframe(err, tol)

if out:
logger.info("RESULT: check PASSED!")
else:
Expand All @@ -51,6 +49,8 @@ def check(input_file_ref, input_file_cur, tolerance_file_name, factor):
logger.info("\nTolerance")
logger.info(tol)
logger.info("\nError relative to tolerance")
logger.info(div)
logger.info(
compute_division_df(err, tol)
)

sys.exit(0 if out else 1)
80 changes: 34 additions & 46 deletions engine/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
- Creating copies of specified model files and optionally copying all files from
a source directory to a destination directory.
- Applying perturbations to arrays in NetCDF files based on a specified
amplitude and random seed.
perturbation amplitude and random seed.
"""

import os
Expand All @@ -17,35 +17,15 @@
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.netcdf_io import nc4_get_copy
from util.utils import get_seed_from_member_id
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
path = os.path.abspath(in_path)
if not os.path.exists(out_path):
logger.info("creating new directory: %s", out_path)
os.makedirs(out_path)
data = [nc4_get_copy(f"{path}/{f}", f"{out_path}/{f}") for f in in_files]

if copy_all_files:
all_files = os.listdir(path)
# disregard the input files which are copied above
other_files = [f for f in all_files if f not in in_files]
# copy all other files
for f in other_files:
shutil.copy(f"{in_path}/{f}", out_path)

return data


def perturb_array(array, s, a):
shape = array.shape
np.random.seed(s)
p = (
np.random.rand(*shape) * 2 - 1
) * a + 1 # *2-1: map to [-1,1), *a: rescale to amplitude, +1 perturb around 1
parray = np.copy(array * p)
return parray
def perturb_array(array, seed, perturb_amplitude):
np.random.seed(seed)
perturbation = (
np.random.rand(*array.shape) * 2 - 1
) * perturb_amplitude + 1 # *2-1: map to [-1,1), *perturb_amplitude: rescale to perturbation amplitude, +1 perturb around 1
return np.copy(array * perturbation)


@click.command()
Expand Down Expand Up @@ -102,29 +82,37 @@ def perturb(
copy_all_files,
): # pylint: disable=unused-argument, too-many-positional-arguments

for m_id in member_ids:
for member_id in member_ids:

if member_type:
m_id_str = member_type + "_" + str(m_id)
else:
m_id_str = str(m_id)

perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
member_id=m_id_str
)

data = create_perturb_files(
model_input_dir,
files,
perturbed_model_input_dir_member_id,
copy_all_files,
perturbed_dir = perturbed_model_input_dir.format(
member_id=prepend_type_to_member_id(member_type, member_id)
)

model_input_dir_abspath = os.path.abspath(model_input_dir)

# Create directory for perturbed ensemble member
if not os.path.exists(perturbed_dir):
logger.info("creating new directory: %s", perturbed_dir)
os.makedirs(perturbed_dir)

# Add perturbed files to member directory
data = [
nc4_get_copy(
f"{model_input_dir_abspath}/{f}", f"{perturbed_dir}/{f}"
) for f in files
]

for d in data:
for vn in variable_names:
d.variables[vn][:] = perturb_array(
d.variables[vn][:],
get_seed_from_member_id(m_id),
perturb_amplitude,
array=d.variables[vn][:],
seed=get_seed_from_member_id(member_id),
perturb_amplitude=perturb_amplitude,
)
d.close()

# Copy rest of the files in `model_input_dir` to the perturbed ensemble member dir (`perturbed_dir`)
if copy_all_files:
for f in os.listdir(model_input_dir_abspath):
if f not in files: # files added manually via `files` flag already copied above
shutil.copy(os.path.join(model_input_dir, f), perturbed_dir)
31 changes: 10 additions & 21 deletions engine/run_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,7 @@

from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.utils import get_seed_from_member_id


def is_float(string):
try:
float(string)
return True
except ValueError:
return False
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


# replace an assignment statement (left=*right* to left=new)
Expand Down Expand Up @@ -234,33 +226,30 @@ def run_ensemble(
append_job(job, job_list, parallel)

# run the ensemble
for m_id in member_ids:
for member_id in member_ids:

Path(perturbed_run_dir.format(member_id=str(m_id))).mkdir(
Path(perturbed_run_dir.format(member_id=str(member_id))).mkdir(
exist_ok=True, parents=True
)
os.chdir(perturbed_run_dir.format(member_id=str(m_id)))

if member_type:
m_id_str = member_type + "_" + str(m_id)
else:
m_id_str = str(m_id)
os.chdir(perturbed_run_dir.format(member_id=str(member_id)))

runscript = f"{run_dir}/{run_script_name}"

complete_member_id = prepend_type_to_member_id(member_type, member_id)

perturbed_run_dir_path = perturbed_run_dir.format(member_id=m_id_str)
perturbed_run_script_path = perturbed_run_script_name.format(member_id=m_id_str)
perturbed_run_dir_path = perturbed_run_dir.format(member_id=complete_member_id)
perturbed_run_script_path = perturbed_run_script_name.format(member_id=complete_member_id)
perturbed_runscript = f"{perturbed_run_dir_path}/{perturbed_run_script_path}"

prepare_perturbed_run_script(
runscript,
perturbed_runscript,
experiment_name,
perturbed_experiment_name.format(member_id=m_id_str),
perturbed_experiment_name.format(member_id=complete_member_id),
lhs,
rhs_new,
rhs_old,
get_seed_from_member_id(m_id),
get_seed_from_member_id(member_id),
)

if not dry:
Expand Down
61 changes: 20 additions & 41 deletions engine/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

This command line tool provides functionality for:
- Creating and saving statistics dataframes from specified model output files.
- Verifying that lists of values are monotonically increasing.
- Generating statistics for both ensemble and reference model runs.
"""

Expand All @@ -16,10 +15,7 @@
from util.click_util import CommaSeperatedInts, cli_help
from util.dataframe_ops import df_from_file_ids
from util.log_handler import logger


def monotonically_increasing(li):
return all(x <= y for x, y in zip(li[:-1], li[1:]))
from util.utils import prepend_type_to_member_id


def create_stats_dataframe(input_dir, file_id, stats_file_name, file_specification):
Expand All @@ -33,31 +29,6 @@ def create_stats_dataframe(input_dir, file_id, stats_file_name, file_specificati
return df


def process_member(
member_id,
member_type,
model_output_dir,
perturbed_model_output_dir,
file_id,
stats_file_name,
file_specification,
): # pylint: disable=too-many-positional-arguments
if member_id == 0:
input_dir = model_output_dir
m_id_str = "ref"
else:
m_id_str = str(member_id)
if member_type:
m_id_str = member_type + "_" + m_id_str
input_dir = perturbed_model_output_dir.format(member_id=m_id_str)
create_stats_dataframe(
input_dir,
file_id,
stats_file_name.format(member_id=m_id_str),
file_specification,
)


@click.command()
@click.option(
"--ensemble/--no-ensemble",
Expand Down Expand Up @@ -116,21 +87,29 @@ def stats(

# compute stats for the ensemble and the reference run
if ensemble:
df_args = []

member_ids.append(0)
with Pool() as p:
args = [
for member_id in member_ids:
if member_id == 0:
complete_member_id = "ref"
output_dir = model_output_dir
else:
complete_member_id = prepend_type_to_member_id(member_type, member_id)
output_dir = perturbed_model_output_dir.format(member_id=complete_member_id)

df_args.append(
(
m_id,
member_type,
model_output_dir,
perturbed_model_output_dir,
output_dir,
file_id,
stats_file_name,
file_specification,
stats_file_name.format(member_id=complete_member_id),
file_specification
)
for m_id in member_ids
]
p.starmap(process_member, args)
)

with Pool() as p:
p.starmap(create_stats_dataframe, df_args)

else:
create_stats_dataframe(
model_output_dir,
Expand Down
5 changes: 4 additions & 1 deletion tests/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from util.utils import get_seed_from_member_id
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


def test_get_seed_from_member_id_invalid():
Expand Down Expand Up @@ -36,3 +36,6 @@ def test_get_seed_from_member_id_unique_seeds():
"""
seeds = [get_seed_from_member_id(i) for i in range(1, 121)]
assert len(seeds) == len(set(seeds)), "Seeds are not unique!"

def test_prepend_type_to_member_id():
assert prepend_type_to_member_id("double", 3) == "double_3"
2 changes: 1 addition & 1 deletion util/dataframe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def compute_rel_diff_dataframe(df1, df2):
return out


def compute_div_dataframe(df1, df2):
def compute_division_df(df1, df2):
# avoid division by 0 and put nan instead
out = df1 / df2.replace({0: np.nan})
# put 0 if numerator is 0 as well
Expand Down
3 changes: 3 additions & 0 deletions util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def unique_elements(inlist):
unique.append(element)
return unique

def prepend_type_to_member_id(member_type, member_id):
return (member_type + "_" + str(member_id)) if member_type else str(member_id)


def first_idx_of(li, el):
return li.index(el)
Expand Down
Loading