Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ To work with probtest, a configuration file is needed. This file contains all th

### perturb

Perturbs netcdf files that can be used as input by the model.
Creates a set of netcdf files that can be used as input for a perturbed model ensemble.

### run-ensemble

Expand Down
13 changes: 6 additions & 7 deletions engine/cdo_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from util.dataframe_ops import df_from_file_ids
from util.file_system import file_names_from_pattern
from util.log_handler import logger
from util.utils import prepend_type_to_member_id


def compute_rel_diff(var1, var2):
Expand Down Expand Up @@ -127,11 +128,6 @@ def cdo_table(
file_specification,
): # pylint: disable=too-many-positional-arguments

if member_type:
member_id_str = member_type + "_" + str(member_id)
else:
member_id_str = str(member_id)

file_specification = file_specification[0] # can't store dicts as defaults in click
assert isinstance(file_specification, dict), "must be dict"

Expand All @@ -142,6 +138,7 @@ def cdo_table(

# step 1: compute rel-diff netcdf files
with tempfile.TemporaryDirectory() as tmpdir:
typed_member_id = prepend_type_to_member_id(member_type, member_id)
for _, file_pattern in file_id:
ref_files, err = file_names_from_pattern(model_output_dir, file_pattern)
if err > 0:
Expand All @@ -151,7 +148,8 @@ def cdo_table(
continue
ref_files.sort()
perturb_files, err = file_names_from_pattern(
perturbed_model_output_dir.format(member_id=member_id_str), file_pattern
perturbed_model_output_dir.format(member_id=typed_member_id),
file_pattern,
)
if err > 0:
logger.info(
Expand All @@ -165,7 +163,8 @@ def cdo_table(
continue
ref_data = xr.open_dataset(f"{model_output_dir}/{rf}")
perturb_data = xr.open_dataset(
f"{perturbed_model_output_dir.format(member_id=member_id_str)}/{pf}"
f"{perturbed_model_output_dir.format(member_id=typed_member_id)}"
f"/{pf}"
)
diff_data = ref_data.copy()
varnames = [
Expand Down
6 changes: 2 additions & 4 deletions engine/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import click

from util.click_util import cli_help
from util.dataframe_ops import check_stats_file_with_tolerances, compute_div_dataframe
from util.dataframe_ops import check_stats_file_with_tolerances, compute_division
from util.log_handler import logger


Expand Down Expand Up @@ -40,8 +40,6 @@ def check(input_file_ref, input_file_cur, tolerance_file_name, factor):
tolerance_file_name, input_file_ref, input_file_cur, factor
)

div = compute_div_dataframe(err, tol)

if out:
logger.info("RESULT: check PASSED!")
else:
Expand All @@ -51,6 +49,6 @@ def check(input_file_ref, input_file_cur, tolerance_file_name, factor):
logger.info("\nTolerance")
logger.info(tol)
logger.info("\nError relative to tolerance")
logger.info(div)
logger.info(compute_division(err, tol))

sys.exit(0 if out else 1)
82 changes: 37 additions & 45 deletions engine/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
- Creating copies of specified model files and optionally copying all files from
a source directory to a destination directory.
- Applying perturbations to arrays in NetCDF files based on a specified
amplitude and random seed.
perturbation amplitude and fixed seed.
"""

import os
Expand All @@ -17,35 +17,16 @@
from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.netcdf_io import nc4_get_copy
from util.utils import get_seed_from_member_id
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


def create_perturb_files(in_path, in_files, out_path, copy_all_files=False):
    """Copy the selected input files into ``out_path`` and return the copies.

    ``out_path`` is created (and the creation logged) when it does not exist
    yet.  Every name in ``in_files`` is then duplicated from ``in_path`` into
    ``out_path`` via ``nc4_get_copy``; the objects returned by those calls are
    collected in a list and returned.  When ``copy_all_files`` is true, all
    remaining entries of ``in_path`` are additionally copied with
    ``shutil.copy``.
    """
    source_dir = os.path.abspath(in_path)

    if not os.path.exists(out_path):
        logger.info("creating new directory: %s", out_path)
        os.makedirs(out_path)

    copies = []
    for name in in_files:
        copies.append(nc4_get_copy(f"{source_dir}/{name}", f"{out_path}/{name}"))

    if copy_all_files:
        # entries named in `in_files` were already copied above
        for entry in os.listdir(source_dir):
            if entry in in_files:
                continue
            shutil.copy(f"{in_path}/{entry}", out_path)

    return copies


def perturb_array(array, s, a):
    """Return a copy of ``array`` multiplied element-wise by random factors.

    The global NumPy random generator is seeded with ``s`` before drawing, so
    the same seed and shape always yield the same factors.  For a
    non-negative ``a`` the factors lie in ``[1 - a, 1 + a)``.
    """
    dims = array.shape
    np.random.seed(s)
    unit_noise = np.random.rand(*dims)  # uniform draw in [0, 1)
    centred = unit_noise * 2 - 1  # map to [-1, 1)
    factors = centred * a + 1  # rescale to amplitude, perturb around 1
    return np.copy(array * factors)
def perturb_array(array, seed, perturb_amplitude):
    """Return a copy of ``array`` scaled element-wise by seeded random factors.

    The global NumPy random generator is re-seeded with ``seed`` on every
    call, which makes the result reproducible for a given seed and array
    shape.  For a non-negative ``perturb_amplitude`` each factor lies in
    ``[1 - perturb_amplitude, 1 + perturb_amplitude)``.
    """
    np.random.seed(seed)
    uniform_draw = np.random.rand(*array.shape)
    # stretch [0, 1) to [-1, 1), scale by the perturbation amplitude and
    # shift so that the factors are centred on 1
    scale_factors = (uniform_draw * 2 - 1) * perturb_amplitude + 1
    perturbed = array * scale_factors
    return np.copy(perturbed)


@click.command()
Expand Down Expand Up @@ -102,29 +83,40 @@ def perturb(
copy_all_files,
): # pylint: disable=unused-argument, too-many-positional-arguments

for m_id in member_ids:

if member_type:
m_id_str = member_type + "_" + str(m_id)
else:
m_id_str = str(m_id)
# `member_id` is already an input option
for _member_id in member_ids:

perturbed_model_input_dir_member_id = perturbed_model_input_dir.format(
member_id=m_id_str
perturbed_dir = perturbed_model_input_dir.format(
member_id=prepend_type_to_member_id(member_type, _member_id)
)

data = create_perturb_files(
model_input_dir,
files,
perturbed_model_input_dir_member_id,
copy_all_files,
)
model_input_dir_abspath = os.path.abspath(model_input_dir)

# Create directory for perturbed ensemble member
if not os.path.exists(perturbed_dir):
logger.info("creating new directory: %s", perturbed_dir)
os.makedirs(perturbed_dir)

# Add perturbed files to member directory
data = [
nc4_get_copy(f"{model_input_dir_abspath}/{f}", f"{perturbed_dir}/{f}")
for f in files
]

for d in data:
for vn in variable_names:
d.variables[vn][:] = perturb_array(
d.variables[vn][:],
get_seed_from_member_id(m_id),
perturb_amplitude,
array=d.variables[vn][:],
seed=get_seed_from_member_id(_member_id),
perturb_amplitude=perturb_amplitude,
)
d.close()

# Copy rest of the files in `model_input_dir` to the perturbed ensemble
# member dir (`perturbed_dir`)
if copy_all_files:
for f in os.listdir(model_input_dir_abspath):
if (
f not in files
): # files added manually via `files` flag already copied above
shutil.copy(os.path.join(model_input_dir, f), perturbed_dir)
33 changes: 12 additions & 21 deletions engine/run_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,7 @@

from util.click_util import CommaSeperatedInts, CommaSeperatedStrings, cli_help
from util.log_handler import logger
from util.utils import get_seed_from_member_id


def is_float(string):
    """Return ``True`` if ``float(string)`` succeeds, ``False`` on ``ValueError``."""
    try:
        float(string)
    except ValueError:
        return False
    else:
        return True
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


# replace an assignment statement (left=*right* to left=new)
Expand Down Expand Up @@ -234,33 +226,32 @@ def run_ensemble(
append_job(job, job_list, parallel)

# run the ensemble
for m_id in member_ids:
for member_id in member_ids:

Path(perturbed_run_dir.format(member_id=str(m_id))).mkdir(
Path(perturbed_run_dir.format(member_id=str(member_id))).mkdir(
exist_ok=True, parents=True
)
os.chdir(perturbed_run_dir.format(member_id=str(m_id)))

if member_type:
m_id_str = member_type + "_" + str(m_id)
else:
m_id_str = str(m_id)
os.chdir(perturbed_run_dir.format(member_id=str(member_id)))

runscript = f"{run_dir}/{run_script_name}"

perturbed_run_dir_path = perturbed_run_dir.format(member_id=m_id_str)
perturbed_run_script_path = perturbed_run_script_name.format(member_id=m_id_str)
typed_member_id = prepend_type_to_member_id(member_type, member_id)

perturbed_run_dir_path = perturbed_run_dir.format(member_id=typed_member_id)
perturbed_run_script_path = perturbed_run_script_name.format(
member_id=typed_member_id
)
perturbed_runscript = f"{perturbed_run_dir_path}/{perturbed_run_script_path}"

prepare_perturbed_run_script(
runscript,
perturbed_runscript,
experiment_name,
perturbed_experiment_name.format(member_id=m_id_str),
perturbed_experiment_name.format(member_id=typed_member_id),
lhs,
rhs_new,
rhs_old,
get_seed_from_member_id(m_id),
get_seed_from_member_id(member_id),
)

if not dry:
Expand Down
61 changes: 21 additions & 40 deletions engine/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

This command line tool provides functionality for:
- Creating and saving statistics dataframes from specified model output files.
- Verifying that lists of values are monotonically increasing.
- Generating statistics for both ensemble and reference model runs.
"""

Expand All @@ -16,10 +15,7 @@
from util.click_util import CommaSeperatedInts, cli_help
from util.dataframe_ops import df_from_file_ids
from util.log_handler import logger


def monotonically_increasing(li):
    """Return ``True`` if every element of ``li`` is ``<=`` its successor.

    Sequences with fewer than two elements are trivially increasing.
    """
    for previous, current in zip(li[:-1], li[1:]):
        # keep the `<=` form so that unordered values (e.g. NaN) count as a violation
        if not previous <= current:
            return False
    return True
from util.utils import prepend_type_to_member_id


def create_stats_dataframe(input_dir, file_id, stats_file_name, file_specification):
Expand All @@ -33,31 +29,6 @@ def create_stats_dataframe(input_dir, file_id, stats_file_name, file_specificati
return df


def process_member(
    member_id,
    member_type,
    model_output_dir,
    perturbed_model_output_dir,
    file_id,
    stats_file_name,
    file_specification,
):  # pylint: disable=too-many-positional-arguments
    """Create the stats dataframe for one ensemble member or the reference.

    A ``member_id`` of ``0`` selects the reference run: data is read from
    ``model_output_dir`` and the label ``"ref"`` is used.  Any other id reads
    from ``perturbed_model_output_dir`` formatted with the member label, which
    is the id prefixed by ``member_type`` and an underscore when a type is
    given.  The same label is substituted into ``stats_file_name``.
    """
    is_reference = member_id == 0
    if is_reference:
        label = "ref"
        source_dir = model_output_dir
    else:
        prefix = member_type + "_" if member_type else ""
        label = prefix + str(member_id)
        source_dir = perturbed_model_output_dir.format(member_id=label)

    create_stats_dataframe(
        source_dir,
        file_id,
        stats_file_name.format(member_id=label),
        file_specification,
    )


@click.command()
@click.option(
"--ensemble/--no-ensemble",
Expand Down Expand Up @@ -116,21 +87,31 @@ def stats(

# compute stats for the ensemble and the reference run
if ensemble:
df_args = []

member_ids.append(0)
with Pool() as p:
args = [
for member_id in member_ids:
if member_id == 0:
typed_member_id = "ref"
output_dir = model_output_dir
else:
typed_member_id = prepend_type_to_member_id(member_type, member_id)
output_dir = perturbed_model_output_dir.format(
member_id=typed_member_id
)

df_args.append(
(
m_id,
member_type,
model_output_dir,
perturbed_model_output_dir,
output_dir,
file_id,
stats_file_name,
stats_file_name.format(member_id=typed_member_id),
file_specification,
)
for m_id in member_ids
]
p.starmap(process_member, args)
)

with Pool() as p:
p.starmap(create_stats_dataframe, df_args)

else:
create_stats_dataframe(
model_output_dir,
Expand Down
6 changes: 5 additions & 1 deletion tests/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from util.utils import get_seed_from_member_id
from util.utils import get_seed_from_member_id, prepend_type_to_member_id


def test_get_seed_from_member_id_invalid():
Expand Down Expand Up @@ -36,3 +36,7 @@ def test_get_seed_from_member_id_unique_seeds():
"""
seeds = [get_seed_from_member_id(i) for i in range(1, 121)]
assert len(seeds) == len(set(seeds)), "Seeds are not unique!"


def test_prepend_type_to_member_id():
    """Check both branches of ``prepend_type_to_member_id``.

    A non-empty member type is joined to the id with an underscore; a falsy
    member type (empty string or ``None``) leaves the stringified id bare.
    """
    assert prepend_type_to_member_id("double", 3) == "double_3"
    # falsy member types must not add a prefix or a stray underscore
    assert prepend_type_to_member_id("", 3) == "3"
    assert prepend_type_to_member_id(None, 3) == "3"
2 changes: 1 addition & 1 deletion util/dataframe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def compute_rel_diff_dataframe(df1, df2):
return out


def compute_div_dataframe(df1, df2):
def compute_division(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
# avoid division by 0 and put nan instead
out = df1 / df2.replace({0: np.nan})
# put 0 if numerator is 0 as well
Expand Down
6 changes: 5 additions & 1 deletion util/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This module provides utility functions for list and string operations, as well as
a function to generate seeds based on a member number for probtest.
a function to get fixed seeds based on the ensemble member number.
"""

import re
Expand All @@ -14,6 +14,10 @@ def unique_elements(inlist):
return unique


def prepend_type_to_member_id(member_type, member_id):
    """Return ``member_id`` as a string, prefixed by ``member_type`` if given.

    With a truthy ``member_type`` the result is ``"<member_type>_<member_id>"``;
    otherwise it is just ``str(member_id)``.
    """
    if not member_type:
        return str(member_id)
    return member_type + "_" + str(member_id)


def first_idx_of(li, el):
    """Return ``li.index(el)``.

    For a list this is the position of the first element equal to ``el``;
    ``ValueError`` is raised when there is no such element.
    """
    position = li.index(el)
    return position

Expand Down
Loading