From 94a0ce05aa444f15dae0a650c9e2f6e48bceee99 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Thu, 23 May 2024 19:32:19 -0700 Subject: [PATCH 01/22] Adds initial parametric functions --- cf_xarray/accessor.py | 47 +++---------- cf_xarray/parametric.py | 109 +++++++++++++++++++++++++++++++ cf_xarray/tests/test_accessor.py | 54 +++------------ 3 files changed, 128 insertions(+), 82 deletions(-) create mode 100644 cf_xarray/parametric.py diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index a10c4886..c704d3e2 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -25,7 +25,7 @@ from xarray.core.rolling import Coarsen, Rolling from xarray.core.weighted import Weighted -from . import sgrid +from . import parametric, sgrid from .criteria import ( _DSG_ROLES, cf_role_criteria, @@ -2668,13 +2668,8 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): """ ds = self._obj - requirements = { - "ocean_s_coordinate_g1": {"depth_c", "depth", "s", "C", "eta"}, - "ocean_s_coordinate_g2": {"depth_c", "depth", "s", "C", "eta"}, - "ocean_sigma_coordinate": {"sigma", "eta", "depth"}, - } - allterms = self.formula_terms + for dim in allterms: if prefix is None: assert ( @@ -2696,6 +2691,7 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): suffix = dim.split("_") zname = f"{prefix}_" + "_".join(suffix[1:]) + # never touched, if standard name is missing it's not included in allterms if "standard_name" not in ds[dim].attrs: continue stdname = ds[dim].attrs["standard_name"] @@ -2704,46 +2700,21 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): terms = {} for key, value in allterms[dim].items(): if value not in ds: + # is this ever hit, if variable is missing it's missing in decoded allterms raise KeyError( f"Variable {value!r} is required to decode coordinate for {dim!r}" " but it is absent in the Dataset." ) terms[key] = ds[value] - absent_terms = requirements[stdname] - set(terms) - if absent_terms: - raise KeyError(f"Required terms {absent_terms} absent in dataset.") - - if stdname == "ocean_s_coordinate_g1": - # S(k,j,i) = depth_c * s(k) + (depth(j,i) - depth_c) * C(k) - S = ( - terms["depth_c"] * terms["s"] - + (terms["depth"] - terms["depth_c"]) * terms["C"] - ) + func = parametric.get_parametric_func(stdname) - # z(n,k,j,i) = S(k,j,i) + eta(n,j,i) * (1 + S(k,j,i) / depth(j,i)) - ztemp = S + terms["eta"] * (1 + S / terms["depth"]) + absent_terms = func._requirements - set(terms) - elif stdname == "ocean_s_coordinate_g2": - # make sure all necessary terms are present in terms - # (depth_c * s(k) + depth(j,i) * C(k)) / (depth_c + depth(j,i)) - S = (terms["depth_c"] * terms["s"] + terms["depth"] * terms["C"]) / ( - terms["depth_c"] + terms["depth"] - ) - - # z(n,k,j,i) = eta(n,j,i) + (eta(n,j,i) + depth(j,i)) * S(k,j,i) - ztemp = terms["eta"] + (terms["eta"] + terms["depth"]) * S - - elif stdname == "ocean_sigma_coordinate": - # z(n,k,j,i) = eta(n,j,i) + sigma(k)*(depth(j,i)+eta(n,j,i)) - ztemp = terms["eta"] + terms["sigma"] * (terms["depth"] + terms["eta"]) - - else: - raise NotImplementedError( - f"Coordinate function for {stdname!r} not implemented yet. Contributions welcome!" - ) + if absent_terms: + raise KeyError(f"Required terms {absent_terms} absent in dataset.") - ds.coords[zname] = ztemp + ds.coords[zname] = func(**terms) @xr.register_dataarray_accessor("cf") diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py new file mode 100644 index 00000000..83e4bc8c --- /dev/null +++ b/cf_xarray/parametric.py @@ -0,0 +1,109 @@ +import numpy as np +import inspect + +_REGISTRY = {} + +def register(name=None): + def wrapper(func): + func_name = name or func.__name__ + + arg_spec = inspect.getfullargspec(func) + + func._requirements = set(arg_spec.args) + + if func_name not in _REGISTRY: + _REGISTRY[func_name] = func + return wrapper + +def get_parametric_func(stdname): + try: + return _REGISTRY[stdname] + except KeyError: + raise NotImplementedError( + f"Coordinate function for {stdname!r} not implemented yet. Contributions welcome!" + ) + +@register() +def atmosphere_ln_pressure_coordinate(p0, lev): + return p0 * np.exp(-lev) + +@register() +def atmosphere_sigma_coordinate(sigma, ps, ptop): + return ptop + sigma * (ps - ptop) + +@register() +def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): + if a is None: + value = ap + b * ps + else: + value = a * p + b * ps + + return value + +@register() +def atmosphere_hybrid_height_coordinate(a, b, orog): + return a + b * orog + +@register() +def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): + return a + ztop + b1 * zsurf1 + b2 * zsurf2 + +@register() +def ocean_sigma_coordinate(sigma, eta, depth): + return eta + sigma * (depth + eta) + +@register() +def ocean_s_coordinate(s, eta, depth, a, b, depth_c): + c = (1 - b) * np.sinh(a * s) / np.sinh(a) + b * (np.tanh(a * (s + 0.5)) / 2 * np.tanh(0.5 * a) - 0.5) + + return eta * (1 + s) + depth_c * s + (depth - depth_c) * c + +@register() +def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): + s = depth_c * s + (depth - depth_c) * C + + return s + eta * (1 + s / depth) + +@register() +def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): + s = (depth_c * s + depth * C) / (depth_c + depth) + + return eta + (eta + depth) * s + +@register() +def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): + n, j, i = eta.shape + + k = sigma.shape[0] + + z = np.zeros((n, k, j, i)) + + sigma_defined = ~np.isnan(sigma) + + zlev_defined = ~np.isnan(zlev) + + depth_min = np.minimum(depth_c, depth[np.newaxis, :, :]) + + z[:, sigma_defined, :, :] = eta[:, np.newaxis, :, :] + sigma[sigma_defined, np.newaxis, np.newaxis] * (depth_min + eta[:, np.newaxis, :, :]) + + z[:, zlev_defined, :, :] = zlev[zlev_defined] + + return z + +@register() +def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): + k = sigma.shape[0] + + j, i = depth.shape + + z = np.zeros((k, j, i)) + + f = 0.5 * (z1 + z2) + 0.5 * (z1 - z2) * np.tanh(2 * a / (z1 - z2) * (depth - href)) + + above_kc = sigma.k > k_c + + z[above_kc, :, :] = f + (sigma[above_kc] - 1) * (depth[np.newaxis, :, :] - f) + + z[~above_kc, :, :] = sigma[~above_kc] * f + + return z diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 60a68114..4f03f616 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -31,7 +31,6 @@ forecast, mollwds, multiple, - pomds, popds, romsds, rotds, @@ -1304,52 +1303,19 @@ def test_Z_vs_vertical_ROMS() -> None: ) -def test_param_vcoord_ocean_s_coord() -> None: - romsds.s_rho.attrs["standard_name"] = "ocean_s_coordinate_g2" - Zo_rho = (romsds.hc * romsds.s_rho + romsds.Cs_r * romsds.h) / ( - romsds.hc + romsds.h - ) - expected = romsds.zeta + (romsds.zeta + romsds.h) * Zo_rho - romsds.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) - assert_allclose( - romsds.z_rho.reset_coords(drop=True), expected.reset_coords(drop=True) - ) - - romsds.s_rho.attrs["standard_name"] = "ocean_s_coordinate_g1" - Zo_rho = romsds.hc * (romsds.s_rho - romsds.Cs_r) + romsds.Cs_r * romsds.h - - expected = Zo_rho + romsds.zeta * (1 + Zo_rho / romsds.h) - romsds.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) - assert_allclose( - romsds.z_rho.reset_coords(drop=True), expected.reset_coords(drop=True) - ) - - romsds.cf.decode_vertical_coords(outnames={"s_rho": "ZZZ_rho"}) - assert "ZZZ_rho" in romsds.coords - - copy = romsds.copy(deep=False) - del copy["zeta"] - with pytest.raises(KeyError): - copy.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) - - copy = romsds.copy(deep=False) - copy.s_rho.attrs["formula_terms"] = "s: s_rho C: Cs_r depth: h depth_c: hc" - with pytest.raises(KeyError): - copy.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) - +def test_decode_vertical_coords() -> None: + with pytest.raises( + AssertionError, match="if prefix is None, outnames must be provided" + ): + romsds.cf.decode_vertical_coords() -def test_param_vcoord_ocean_sigma_coordinate() -> None: - expected = pomds.zeta + pomds.sigma * (pomds.depth + pomds.zeta) - pomds.cf.decode_vertical_coords(outnames={"sigma": "z"}) - assert_allclose(pomds.z.reset_coords(drop=True), expected.reset_coords(drop=True)) + with pytest.warns(DeprecationWarning): + romsds.cf.decode_vertical_coords(prefix="z_rho") - copy = pomds.copy(deep=False) - del copy["zeta"] - with pytest.raises(AssertionError): - copy.cf.decode_vertical_coords() + romsds_less_h = romsds.drop_vars(["h"]) - with pytest.raises(KeyError): - copy.cf.decode_vertical_coords(outnames={}) + with pytest.raises(KeyError, match="Required terms {'depth'} absent in dataset."): + romsds_less_h.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) def test_formula_terms() -> None: From 4bb4c02ca773b0adf9ad3826d5c6afb899632dd8 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Fri, 7 Jun 2024 09:58:46 -0700 Subject: [PATCH 02/22] Updates parametric module and adds tests --- cf_xarray/accessor.py | 16 +- cf_xarray/parametric.py | 621 ++++++++++++++++++++++++++--- cf_xarray/tests/test_parametric.py | 485 ++++++++++++++++++++++ 3 files changed, 1064 insertions(+), 58 deletions(-) create mode 100644 cf_xarray/tests/test_parametric.py diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index c704d3e2..83f2cc77 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2707,12 +2707,16 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): ) terms[key] = ds[value] - func = parametric.get_parametric_func(stdname) - - absent_terms = func._requirements - set(terms) - - if absent_terms: - raise KeyError(f"Required terms {absent_terms} absent in dataset.") + try: + func = parametric.func_from_stdname(stdname) + except AttributeError: + # Should occur since stdname is check before + raise NotImplementedError( + f"Coordinate function for {stdname!r} not implmented yet. Contributions welcome!" + ) from None + + # let KeyError propagate + parametric.check_requirements(func, terms) ds.coords[zname] = func(**terms) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 83e4bc8c..f9e128fd 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -1,109 +1,626 @@ -import numpy as np import inspect +import sys + +import numpy as np +import xarray as xr + +ocean_stdname_map = { + "altitude": { + "zlev": "altitude", + "eta": "sea_surface_height_above_geoid", + "depth": "sea_floor_depth_below_geoid", + }, + "height_above_geopotential_datum": { + "zlev": "height_above_geopotential_datum", + "eta": "sea_surface_height_above_ geopotential_datum", + "depth": "sea_floor_depth_below_ geopotential_datum", + }, + "height_above_reference_ellipsoid": { + "zlev": "height_above_reference_ellipsoid", + "eta": "sea_surface_height_above_ reference_ellipsoid", + "depth": "sea_floor_depth_below_ reference_ellipsoid", + }, + "height_above_mean_sea_level": { + "zlev": "height_above_mean_sea_level", + "eta": "sea_surface_height_above_mean_ sea_level", + "depth": "sea_floor_depth_below_mean_ sea_level", + }, +} + + +def _derive_ocean_stdname(**kwargs): + """Derive standard name for computer ocean coordinates. + + Uses the concatentation of formula terms `zlev`, `eta`, and `depth` + standard names to compare against formula term and standard names + from a table. This can occur with any combination e.g. `zlev`, or + `zlev` + `depth`. If a match is found the standard name for the + computed value is returned. + + Parameters + ---------- + zlev : dict + Attributes for `zlev` variable. + eta : dict + Attributes for `eta` variable. + depth : dict + Attributes for `depth` variable. + + Returns + ------- + str + Standard name for the computer value. + + Raises + ------ + ValueError + If `kwargs` is empty, missing values for `kwargs` keys, or could not derive the standard name. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#table-computed-standard-names + """ + + found_stdname = None + + allowed_names = {"zlev", "eta", "depth"} + + if len(kwargs) == 0 or not (set(kwargs) <= allowed_names): + raise ValueError( + f"Must provide atleast one of {', '.join(sorted(allowed_names))}." + ) + + search_term = "" + + for x, y in sorted(kwargs.items(), key=lambda x: x[0]): + try: + search_term = f"{search_term}{y['standard_name']}" + except TypeError: + raise ValueError( + f"The values for {', '.join(sorted(kwargs.keys()))} cannot be `None`." + ) from None + except KeyError: + raise ValueError( + f"The standard name for the {x!r} variable is not available." + ) from None + + for x, y in ocean_stdname_map.items(): + check_term = "".join( + [ + y[i] + for i, j in sorted(kwargs.items(), key=lambda x: x[0]) + if j is not None + ] + ) + + if search_term == check_term: + found_stdname = x + + break + + if found_stdname is None: + stdnames = ", ".join( + [y["standard_name"] for _, y in sorted(kwargs.items(), key=lambda x: x[0])] + ) + + raise ValueError( + f"Could not derive standard name from combination of {stdnames}." + ) -_REGISTRY = {} + return found_stdname -def register(name=None): - def wrapper(func): - func_name = name or func.__name__ - arg_spec = inspect.getfullargspec(func) +def check_requirements(func, terms): + """Checks terms against function requirements. - func._requirements = set(arg_spec.args) + Uses `func` argument specification as requirements and checks terms against this. + Postitional arguments without a default are required but when a default value is + provided the arguement is considered optional. Atleast one optional argument must + be present (special case for atmosphere_hybrid_sigma_pressure_coordinate). - if func_name not in _REGISTRY: - _REGISTRY[func_name] = func - return wrapper + Parameters + ---------- + func : function + Function to check requirements. + terms : list + List of terms to check `func` requirements against. + + Raises + ------ + KeyError + If `terms` is empty or missing required/optional terms. + """ + if not isinstance(terms, set): + terms = set(terms) + + spec = inspect.getfullargspec(func) + + args = spec.args or [] + + if len(terms) == 0: + raise KeyError(f"Required terms {', '.join(sorted(args))} absent in dataset.") + + n = len(spec.defaults or []) + + # last `n` arguments are optional + opt = set(args[len(args) - n :]) + + req = set(args) - opt + + # req must all be present in terms + if (req & terms) != req: + req_diff = sorted(req - terms) + + opt_err = "" + + # optional is not present in terms + if len(opt) > 0 and not (opt <= terms): + opt_err = f"and atleast one optional term {', '.join(sorted(opt))} " + + raise KeyError( + f"Required terms {', '.join(req_diff)} {opt_err}absent in dataset." + ) -def get_parametric_func(stdname): - try: - return _REGISTRY[stdname] - except KeyError: - raise NotImplementedError( - f"Coordinate function for {stdname!r} not implemented yet. Contributions welcome!" + # atleast one optional is in diff, only required for atmoshphere hybrid sigma pressure coordinate + if len(opt) > 0 and not (opt <= terms): + raise KeyError( + f"Atleast one of the optional terms {', '.join(sorted(opt))} is absent in dataset." ) -@register() + +def func_from_stdname(stdname): + """Get function from module. + + Uses `stdname` to return function from module. + + Parameters + ---------- + stdname : str + Name of the function. + + Raises + ------ + AttributeError + If a function name `stdname` is not in the module. + """ + m = sys.modules[__name__] + + return getattr(m, stdname) + + def atmosphere_ln_pressure_coordinate(p0, lev): - return p0 * np.exp(-lev) + """Atmosphere natural log pressure coordinate. + + Standard name: atmosphere_ln_pressure_coordinate + + Parameters + ---------- + p0 : xr.DataArray + Reference pressure. + lev : xr.DataArray + Vertical dimensionless coordinate. + + Returns + ------- + xr.DataArray + A DataArray with new pressure coordinate. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-natural-log-pressure-coordinate + """ + p = p0 * np.exp(-lev) + + p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") + + return p + -@register() def atmosphere_sigma_coordinate(sigma, ps, ptop): - return ptop + sigma * (ps - ptop) + """Atmosphere sigma coordinate. + + Standard name: atmosphere_sigma_coordinate + + Parameters + ---------- + sigma : xr.DataArray + Vertical dimensionless coordinate. + ps : xr.DataArray + Horizontal surface pressure. + + Returns + ------- + xr.DataArray + A DataArray with new pressure coordinate. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_sigma_coordinate + """ + p = ptop + sigma * (ps - ptop) + + p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") + + return p.transpose("time", "lev", "lat", "lon") + -@register() def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): + """Atmosphere hybrid sigma pressure coordinate. + + Standard name: atmosphere_hybrid_sigma_pressure_coordinate + + Parameters + ---------- + b : xr.DataArray + Component of hybrid coordinate. + ps : xr.DataArray + Horizontal surface pressure. + p0 : xr.DataArray + Reference pressure. + a : xr.DataArray + Component of hybrid coordinate. + ap : xr.DataArray + Component of hybrid coordinate. + + Returns + ------- + xr.DataArray + A DataArray with new pressure coordinate. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_hybrid_sigma_pressure_coordinate + """ if a is None: - value = ap + b * ps + p = ap + b * ps else: - value = a * p + b * ps + p = a * p0 + b * ps + + p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") + + return p.transpose("time", "lev", "lat", "lon") - return value -@register() def atmosphere_hybrid_height_coordinate(a, b, orog): - return a + b * orog + """Atmosphere hybrid height coordinate. + + Standard name: atmosphere_hybrid_height_coordinate + + Parameters + ---------- + a : xr.DataArray + Height. + b : xr.DataArray + Dimensionless. + orog : xr.DataArray + Height of the surface above the datum. + + Returns + ------- + xr.DataArray + A DataArray with the height above the datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-hybrid-height-coordinate + """ + z = a + b * orog + + orog_stdname = orog.attrs["standard_name"] + + if orog_stdname == "surface_altitude": + out_stdname = "altitude" + elif orog_stdname == "surface_height_above_geopotential_datum": + out_stdname = "height_above_geopotential_datum" + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") + -@register() def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): - return a + ztop + b1 * zsurf1 + b2 * zsurf2 + """Atmosphere smooth level vertical (SLEVE) coordinate. + + Standard name: atmosphere_sleve_coordinate + + Parameters + ---------- + a : xr.DataArray + Dimensionless coordinate whcih defines hybrid level. + b1 : xr.DataArray + Dimensionless coordinate whcih defines hybrid level. + b2 : xr.DataArray + Dimensionless coordinate whcih defines hybrid level. + ztop : xr.DataArray + Height above the top of the model above datum. + zsurf1 : xr.DataArray + Large-scale component of the topography. + zsurf2 : xr.DataArray + Small-scale component of the topography. + + Returns + ------- + xr.DataArray + A DataArray with the height above the datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_smooth_level_vertical_sleve_coordinate + """ + z = a * ztop + b1 * zsurf1 + b2 * zsurf2 + + ztop_stdname = ztop.attrs["standard_name"] + + if ztop_stdname == "altitude_at_top_of_atmosphere_model": + out_stdname = "altitude" + elif ztop_stdname == "height_above_geopotential_datum_at_top_of_atmosphere_model": + out_stdname = "height_above_geopotential_datum" + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") + -@register() def ocean_sigma_coordinate(sigma, eta, depth): - return eta + sigma * (depth + eta) + """Ocean sigma coordinate. + + Standard name: ocean_sigma_coordinate + + Parameters + ---------- + sigma : xr.DataArray + Vertical dimensionless coordinate. + eta : xr.DataArray + Height of the sea surface (positive upwards) relative to the datum. + depth : xr.DataArray + Distance (positive value) from the datum to the sea floor. + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to the datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_coordinate + """ + z = eta + sigma * (depth + eta) + + out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") + -@register() def ocean_s_coordinate(s, eta, depth, a, b, depth_c): - c = (1 - b) * np.sinh(a * s) / np.sinh(a) + b * (np.tanh(a * (s + 0.5)) / 2 * np.tanh(0.5 * a) - 0.5) + """Ocean s-coordinate. + + Standard name: ocean_s_coordinate + + Parameters + ---------- + s : xr.DataArray + Dimensionless coordinate. + eta : xr.DataArray + Height of the sea surface (positive upwards) relative to the datum. + depth : xr.DataArray + Distance (positive value) from the datum to the sea floor. + a : xr.DataArray + Constant controlling stretch. + b : xr.DataArray + Constant controlling stretch. + depth_c : xr.DataArray + Constant controlling stretch. + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to the datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate + """ + C = (1 - b) * np.sinh(a * s) / np.sinh(a) + b * ( + np.tanh(a * (s + 0.5)) / 2 * np.tanh(0.5 * a) - 0.5 + ) + + z = eta * (1 + s) + depth_c * s + (depth - depth_c) * C + + out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") - return eta * (1 + s) + depth_c * s + (depth - depth_c) * c -@register() def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): + """Ocean s-coordinate, generic form 1. + + Standard name: ocean_s_coordinate_g1 + + Parameters + ---------- + s : xr.DataArray + Dimensionless coordinate. + C : xr.DataArray + Dimensionless vertical coordinate stretching function. + eta : xr.DataArray + Height of the ocean surface (positive upwards) relative to the ocean datum. + depth : xr.DataArray + Distance from ocean datum to sea floor (positive value). + depth_c : xr.DataArray + Constant (positive value) is a critical depth controlling the stretching. + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to ocean datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_1 + """ s = depth_c * s + (depth - depth_c) * C - return s + eta * (1 + s / depth) + z = s + eta * (1 + s / depth) + + out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") + -@register() def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): + """Ocean s-coordinate, generic form 2. + + Standard name: ocean_s_coordinate_g2 + + Parameters + ---------- + s : xr.DataArray + Dimensionless coordinate. + C : xr.DataArray + Dimensionless vertical coordinate stretching function. + eta : xr.DataArray + Height of the ocean surface (positive upwards) relative to the ocean datum. + depth : xr.DataArray + Distance from ocean datum to sea floor (positive value). + depth_c : xr.DataArray + Constant (positive value) is a critical depth controlling the stretching. + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to ocean datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_2 + """ s = (depth_c * s + depth * C) / (depth_c + depth) - return eta + (eta + depth) * s + z = eta + (eta + depth) * s + + out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + + return z.transpose("time", "lev", "lat", "lon") + -@register() def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): + """Ocean sigma over z coordinate. + + Standard name: ocean_sigma_z_coordinate + + Parameters + ---------- + sigma : xr.DataArray + Coordinate defined only for `nsigma` layers nearest the ocean surface. + eta : xr.DataArray + Height of the ocean surface (positive upwards) relative to ocean datum. + depth : xr.DataArray + Distance from ocean datum to sea floor (positive value). + depth_c : xr.DataArray + Constant. + nsigma : xr.DataArray + Layers nearest the ocean surface. + zlev : xr.DataArray + Coordinate defined only for `nlayer - nsigma` where `nlayer` is the size of the vertical coordinate. + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to the ocean datum. + + Notes + ----- + The description of this type of parametric vertical coordinate is defective in version 1.8 and earlier versions of the standard, in that it does not state what values the vertical coordinate variable should contain. Therefore, in accordance with the rules, all versions of the standard before 1.9 are deprecated for datasets that use the "ocean sigma over z" coordinate. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_over_z_coordinate + """ n, j, i = eta.shape k = sigma.shape[0] - z = np.zeros((n, k, j, i)) + z = xr.DataArray(np.empty((n, k, j, i)), dims=("time", "lev", "lat", "lon")) + + z_sigma = eta + sigma * (np.minimum(depth_c, depth) + eta) - sigma_defined = ~np.isnan(sigma) + z = xr.where(~np.isnan(sigma), z_sigma, z) - zlev_defined = ~np.isnan(zlev) + z = xr.where(np.isnan(sigma), zlev, z) - depth_min = np.minimum(depth_c, depth[np.newaxis, :, :]) + out_stdname = _derive_ocean_stdname( + eta=eta.attrs, depth=depth.attrs, zlev=zlev.attrs + ) - z[:, sigma_defined, :, :] = eta[:, np.newaxis, :, :] + sigma[sigma_defined, np.newaxis, np.newaxis] * (depth_min + eta[:, np.newaxis, :, :]) + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - z[:, zlev_defined, :, :] = zlev[zlev_defined] + return z.transpose("time", "lev", "lat", "lon") - return z -@register() def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): + """Ocean double sigma coordinate. + + Standard name: ocean_double_sigma_coordinate + + Parameters + ---------- + sigma : xr.DataArray + Dimensionless coordinate. + depth : xr.DataArray + Distance (positive value) from datum to the sea floor. + z1 : xr.DataArray + Constant with units of length. + z2 : xr.DataArray + Constant with units of length. + a : xr.DataArray + Constant with units of length. + href : xr.DataArray + Constant with units of length. + k_c : xr.DataArray + + Returns + ------- + xr.DataArray + A DataArray with the height (positive upwards) relative to the datum. + + References + ---------- + Please refer to the CF conventions document : + 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_double_sigma_coordinate + """ k = sigma.shape[0] j, i = depth.shape - z = np.zeros((k, j, i)) - f = 0.5 * (z1 + z2) + 0.5 * (z1 - z2) * np.tanh(2 * a / (z1 - z2) * (depth - href)) - above_kc = sigma.k > k_c + z = xr.DataArray(np.empty((k, j, i)), dims=("lev", "lat", "lon"), name="z") + + z = xr.where(sigma.k <= k_c, sigma * f, z) + + z = xr.where(sigma.k > k_c, f + (sigma - 1) * (depth - f), z) - z[above_kc, :, :] = f + (sigma[above_kc] - 1) * (depth[np.newaxis, :, :] - f) + out_stdname = _derive_ocean_stdname(depth=depth.attrs) - z[~above_kc, :, :] = sigma[~above_kc] * f + z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z + return z.transpose("lev", "lat", "lon") diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py new file mode 100644 index 00000000..b6ac72af --- /dev/null +++ b/cf_xarray/tests/test_parametric.py @@ -0,0 +1,485 @@ +import numpy as np +import pytest +import xarray as xr +from xarray.testing import assert_allclose + +from cf_xarray import parametric + +ps = xr.DataArray(np.ones((2, 2, 2)), dims=("time", "lat", "lon"), name="ps") + +p0 = xr.DataArray( + [ + 10, + ], + name="p0", +) + +a = xr.DataArray([0, 1, 2], dims=("lev",), name="a") + +b = xr.DataArray([6, 7, 8], dims=("lev",), name="b") + +sigma = xr.DataArray( + [0, 1, 2], dims=("lev",), name="sigma", coords={"k": (("lev",), [0, 1, 2])} +) + +eta = xr.DataArray( + np.ones((2, 2, 2)), + dims=("time", "lat", "lon"), + name="eta", + attrs={"standard_name": "sea_surface_height_above_geoid"}, +) + +depth = xr.DataArray( + np.ones((2, 2)), + dims=("lat", "lon"), + name="depth", + attrs={"standard_name": "sea_floor_depth_below_geoid"}, +) + +depth_c = xr.DataArray([30.0], name="depth_c") + +s = xr.DataArray([0, 1, 2], dims=("lev"), name="s") + + +def test_atmosphere_ln_pressure_coordinate(): + lev = xr.DataArray( + [0, 1, 2], + dims=("lev",), + name="lev", + ) + + output = parametric.atmosphere_ln_pressure_coordinate(p0, lev) + + expected = xr.DataArray([10.0, 3.678794, 1.353353], dims=("lev",), name="p") + + assert_allclose(output, expected) + + assert output.name == "p" + assert output.attrs["standard_name"] == "air_pressure" + + +def test_atmosphere_sigma_coordinate(): + ptop = xr.DataArray([0.98692327], name="ptop") + + output = parametric.atmosphere_sigma_coordinate(sigma, ps, ptop) + + expected = xr.DataArray( + [ + [ + [[0.986923, 0.986923], [0.986923, 0.986923]], + [[1.0, 1.0], [1.0, 1.0]], + [[1.013077, 1.013077], [1.013077, 1.013077]], + ], + [ + [[0.986923, 0.986923], [0.986923, 0.986923]], + [[1.0, 1.0], [1.0, 1.0]], + [[1.013077, 1.013077], [1.013077, 1.013077]], + ], + ], + dims=("time", "lev", "lat", "lon"), + coords={"k": (("lev",), [0, 1, 2])}, + name="p", + ) + + assert_allclose(output, expected) + + assert output.name == "p" + assert output.attrs["standard_name"] == "air_pressure" + + +def test_atmosphere_hybrid_sigma_pressure_coordinate(): + ap = xr.DataArray([3, 4, 5], dims=("lev",), name="ap") + + output = parametric.atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=a) + + expected = xr.DataArray( + [ + [ + [[6.0, 6.0], [6.0, 6.0]], + [[17.0, 17.0], [17.0, 17.0]], + [[28.0, 28.0], [28.0, 28.0]], + ], + [ + [[6.0, 6.0], [6.0, 6.0]], + [[17.0, 17.0], [17.0, 17.0]], + [[28.0, 28.0], [28.0, 28.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="p", + ) + + assert_allclose(output, expected) + + assert output.name == "p" + assert output.attrs["standard_name"] == "air_pressure" + + output = parametric.atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, ap=ap) + + expected = xr.DataArray( + [ + [ + [[9.0, 9.0], [9.0, 9.0]], + [[11.0, 11.0], [11.0, 11.0]], + [[13.0, 13.0], [13.0, 13.0]], + ], + [ + [[9.0, 9.0], [9.0, 9.0]], + [[11.0, 11.0], [11.0, 11.0]], + [[13.0, 13.0], [13.0, 13.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="p", + ) + + assert_allclose(output, expected) + + assert output.name == "p" + assert output.attrs["standard_name"] == "air_pressure" + + +def test_atmosphere_hybrid_height_coordinate(): + orog = xr.DataArray( + np.zeros((2, 2, 2)), + dims=("time", "lat", "lon"), + attrs={"standard_name": "surface_altitude"}, + ) + + output = parametric.atmosphere_hybrid_height_coordinate(a, b, orog) + + expected = xr.DataArray( + [ + [ + [[0.0, 0.0], [0.0, 0.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[2.0, 2.0], [2.0, 2.0]], + ], + [ + [[0.0, 0.0], [0.0, 0.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[2.0, 2.0], [2.0, 2.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="p", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_atmosphere_sleve_coordinate(): + b1 = xr.DataArray([0, 0, 1], dims=("lev",), name="b1") + + b2 = xr.DataArray([1, 1, 0], dims=("lev",), name="b2") + + ztop = xr.DataArray( + [30.0], + name="ztop", + attrs={"standard_name": "altitude_at_top_of_atmosphere_model"}, + ) + + zsurf1 = xr.DataArray(np.ones((2, 2, 2)), dims=("time", "lat", "lon")) + + zsurf2 = xr.DataArray(np.ones((2, 2, 2)), dims=("time", "lat", "lon")) + + output = parametric.atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2) + + expected = xr.DataArray( + [ + [ + [[1.0, 1.0], [1.0, 1.0]], + [[31.0, 31.0], [31.0, 31.0]], + [[61.0, 61.0], [61.0, 61.0]], + ], + [ + [[1.0, 1.0], [1.0, 1.0]], + [[31.0, 31.0], [31.0, 31.0]], + [[61.0, 61.0], [61.0, 61.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_sigma_coordinate(): + output = parametric.ocean_sigma_coordinate(sigma, eta, depth) + + expected = xr.DataArray( + [ + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + coords={"k": (("lev",), [0, 1, 2])}, + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_s_coordinate(): + _a = xr.DataArray([1], name="a") + + _b = xr.DataArray([1], name="b") + + output = parametric.ocean_s_coordinate(s, eta, depth, _a, _b, depth_c) + + expected = xr.DataArray( + [ + [ + [[12.403492, 12.403492], [12.403492, 12.403492]], + [[40.434874, 40.434874], [40.434874, 40.434874]], + [[70.888995, 70.888995], [70.888995, 70.888995]], + ], + [ + [[12.403492, 12.403492], [12.403492, 12.403492]], + [[40.434874, 40.434874], [40.434874, 40.434874]], + [[70.888995, 70.888995], [70.888995, 70.888995]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_s_coordinate_g1(): + C = xr.DataArray([0, 1, 2], dims=("lev",), name="C") + + output = parametric.ocean_s_coordinate_g1(s, C, eta, depth, depth_c) + + expected = xr.DataArray( + [ + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_s_coordinate_g2(): + C = xr.DataArray([0, 1, 2], dims=("lev",), name="C") + + output = parametric.ocean_s_coordinate_g2(s, C, eta, depth, depth_c) + + expected = xr.DataArray( + [ + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + [ + [[1.0, 1.0], [1.0, 1.0]], + [[3.0, 3.0], [3.0, 3.0]], + [[5.0, 5.0], [5.0, 5.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_sigma_z_coordinate(): + zlev = xr.DataArray([0, 1, np.nan], dims=("lev",), name="zlev", attrs={"standard_name": "altitude"}) + + _sigma = xr.DataArray([np.nan, np.nan, 3], dims=("lev",), name="sigma") + + output = parametric.ocean_sigma_z_coordinate(_sigma, eta, depth, depth_c, 10, zlev) + + expected = xr.DataArray( + [ + [ + [[0.0, 0.0], [0.0, 0.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[7.0, 7.0], [7.0, 7.0]], + ], + [ + [[0.0, 0.0], [0.0, 0.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[7.0, 7.0], [7.0, 7.0]], + ], + ], + dims=("time", "lev", "lat", "lon"), + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +def test_ocean_double_sigma_coordinate(): + k_c = xr.DataArray( + [ + 1, + ], + name="k_c", + ) + + href = xr.DataArray( + [ + 20.0, + ], + name="href", + ) + + z1 = xr.DataArray( + [ + 10.0, + ], + name="z1", + ) + + z2 = xr.DataArray( + [ + 30.0, + ], + name="z2", + ) + + a = xr.DataArray( + [ + 2.0, + ], + name="a", + ) + + output = parametric.ocean_double_sigma_coordinate( + sigma, depth, z1, z2, a, href, k_c + ) + + expected = xr.DataArray( + [ + [[0.0, 0.0], [0.0, 0.0]], + [[10.010004, 10.010004], [10.010004, 10.010004]], + [[1.0, 1.0], [1.0, 1.0]], + ], + dims=("lev", "lat", "lon"), + coords={"k": (("lev",), [0, 1, 2])}, + name="z", + ) + + assert_allclose(output, expected) + + assert output.name == "z" + assert output.attrs["standard_name"] == "altitude" + + +@pytest.mark.parametrize("input,expected", [ + ({"zlev": {"standard_name": "altitude"}}, "altitude"), + ({"zlev": {"standard_name": "altitude"}, "eta": {"standard_name": "sea_surface_height_above_geoid"}}, "altitude"), + ({"zlev": {"standard_name": "altitude"}, "eta": {"standard_name": "sea_surface_height_above_geoid"}, "depth": {"standard_name": "sea_floor_depth_below_geoid"}}, "altitude"), + ({"eta": {"standard_name": "sea_surface_height_above_geoid"}, "depth": {"standard_name": "sea_floor_depth_below_geoid"}}, "altitude"), +]) +def test_derive_ocean_stdname(input, expected): + output = parametric._derive_ocean_stdname(**input) + + assert output == expected + + +def test_derive_ocean_stdname_no_values(): + with pytest.raises(ValueError, match="Must provide atleast one of depth, eta, zlev."): + parametric._derive_ocean_stdname() + + +def test_derive_ocean_stdname_empty_value(): + with pytest.raises(ValueError, match="The values for zlev cannot be `None`."): + parametric._derive_ocean_stdname(zlev=None) + + +def test_derive_ocean_stdname_no_standard_name(): + with pytest.raises(ValueError, match="The standard name for the 'zlev' variable is not available."): + parametric._derive_ocean_stdname(zlev={}) + + +def test_derive_ocean_stdname_no_match(): + with pytest.raises(ValueError, match="Could not derive standard name from combination of not in any list."): + parametric._derive_ocean_stdname(zlev={"standard_name": "not in any list"}) + + +def test_func_from_stdname(): + with pytest.raises(AttributeError): + parametric.func_from_stdname("test") + + func = parametric.func_from_stdname("atmosphere_ln_pressure_coordinate") + + assert func == parametric.atmosphere_ln_pressure_coordinate + + +def test_check_requirements(): + with pytest.raises(KeyError, match="'Required terms lev, p0 absent in dataset.'"): + parametric.check_requirements(parametric.atmosphere_ln_pressure_coordinate, []) + + parametric.check_requirements( + parametric.atmosphere_ln_pressure_coordinate, ["p0", "lev"] + ) + + with pytest.raises( + KeyError, + match=r"'Required terms b, p0 and atleast one optional term a, ap absent in dataset.'", + ): + parametric.check_requirements( + parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps"] + ) + + with pytest.raises( + KeyError, + match="'Atleast one of the optional terms a, ap is absent in dataset.'", + ): + parametric.check_requirements( + parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "b"] + ) + + with pytest.raises( + KeyError, + match="'Required terms b and atleast one optional term a, ap absent in dataset.'", + ): + parametric.check_requirements( + parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "a"] + ) From 49838e687a55f3588ffe664ad9dfc3b536c78fa5 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Mon, 17 Jun 2024 19:41:48 -0700 Subject: [PATCH 03/22] Adds function to help transpose outputs --- cf_xarray/parametric.py | 91 +++++++++++++++++++++++++----- cf_xarray/tests/test_accessor.py | 6 +- cf_xarray/tests/test_parametric.py | 66 ++++++++++++++++++---- 3 files changed, 137 insertions(+), 26 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index f9e128fd..27a4cf44 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -61,7 +61,6 @@ def _derive_ocean_stdname(**kwargs): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#table-computed-standard-names """ - found_stdname = None allowed_names = {"zlev", "eta", "depth"} @@ -189,6 +188,48 @@ def func_from_stdname(stdname): return getattr(m, stdname) +def derive_dimension_order(output_order, **dim_map): + """Derive dimension ordering from input map. + + This will derive a dimensinal ordering from a map of dimension + identifiers and variables containing the dimensions. + + This is useful when dimension names are not know. + + For example if the desired output ordering was "nkji" where + variable "A" contains "nji" (time, lat, lon) and "B" contains + "k" (height) then the output would be (time, height, lat, lon). + + This also works when dimensions are missing. + + For example if the desired output ordering was "nkji" where + variable "A" contains "n" (time) and "B" contains + "k" (height) then the output would be (time, height). + + Parameters + ---------- + output_order : str + Dimension identifiers in desired order, e.g. "nkji". + **dim_map : dict + Dimension identifiers and variable containing them, e.g. "nji": eta, "k": s. + + Returns + ------- + list + Output dimensions in desired order. + """ + dims = {} + + for x, y in dim_map.items(): + for i, z in enumerate(x): + try: + dims[z] = y.dims[i] + except IndexError: + dims[z] = None + + return tuple(dims[x] for x in list(output_order) if dims[x] is not None) + + def atmosphere_ln_pressure_coordinate(p0, lev): """Atmosphere natural log pressure coordinate. @@ -244,7 +285,9 @@ def atmosphere_sigma_coordinate(sigma, ps, ptop): p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") - return p.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=ps, k=sigma) + + return p.transpose(*output_order) def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): @@ -282,7 +325,9 @@ def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") - return p.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=ps, k=b) + + return p.transpose(*output_order) def atmosphere_hybrid_height_coordinate(a, b, orog): @@ -320,7 +365,9 @@ def atmosphere_hybrid_height_coordinate(a, b, orog): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=orog, k=b) + + return z.transpose(*output_order) def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): @@ -364,7 +411,9 @@ def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=zsurf1, k=a) + + return z.transpose(*output_order) def ocean_sigma_coordinate(sigma, eta, depth): @@ -397,7 +446,9 @@ def ocean_sigma_coordinate(sigma, eta, depth): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=eta, k=sigma) + + return z.transpose(*output_order) def ocean_s_coordinate(s, eta, depth, a, b, depth_c): @@ -440,7 +491,9 @@ def ocean_s_coordinate(s, eta, depth, a, b, depth_c): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=eta, k=s) + + return z.transpose(*output_order) def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): @@ -471,15 +524,17 @@ def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_1 """ - s = depth_c * s + (depth - depth_c) * C + S = depth_c * s + (depth - depth_c) * C - z = s + eta * (1 + s / depth) + z = S + eta * (1 + s / depth) out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=eta, k=s) + + return z.transpose(*output_order) def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): @@ -510,15 +565,17 @@ def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_2 """ - s = (depth_c * s + depth * C) / (depth_c + depth) + S = (depth_c * s + depth * C) / (depth_c + depth) - z = eta + (eta + depth) * s + z = eta + (eta + depth) * S out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=eta, k=s) + + return z.transpose(*output_order) def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): @@ -573,7 +630,9 @@ def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("time", "lev", "lat", "lon") + output_order = derive_dimension_order("nkji", nji=eta, k=sigma) + + return z.transpose(*output_order) def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): @@ -623,4 +682,6 @@ def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - return z.transpose("lev", "lat", "lon") + output_order = derive_dimension_order("kji", ji=depth, k=sigma) + + return z.transpose(*output_order) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 4f03f616..3541804a 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1309,12 +1309,16 @@ def test_decode_vertical_coords() -> None: ): romsds.cf.decode_vertical_coords() + # needs standard names on `eta` and `depth` to derive computed standard name + romsds.h.attrs["standard_name"] = "sea_floor_depth_below_ geopotential_datum" + romsds.zeta.attrs["standard_name"] = "sea_surface_height_above_ geopotential_datum" + with pytest.warns(DeprecationWarning): romsds.cf.decode_vertical_coords(prefix="z_rho") romsds_less_h = romsds.drop_vars(["h"]) - with pytest.raises(KeyError, match="Required terms {'depth'} absent in dataset."): + with pytest.raises(KeyError, match="Required terms depth absent in dataset."): romsds_less_h.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py index b6ac72af..d7356520 100644 --- a/cf_xarray/tests/test_parametric.py +++ b/cf_xarray/tests/test_parametric.py @@ -41,6 +41,21 @@ s = xr.DataArray([0, 1, 2], dims=("lev"), name="s") +@pytest.mark.parametrize( + "order,kwargs,expected", + [ + ("nkji", {"nji": ps, "k": a}, ("time", "lev", "lat", "lon")), + ("nkij", {"nji": ps, "k": a}, ("time", "lev", "lon", "lat")), + ("ijn", {"nji": ps}, ("lon", "lat", "time")), + ("njk", {"nji": ps, "k": a}, ("time", "lat", "lev")), + ], +) +def test_derive_dimension_order(order, kwargs, expected): + order = parametric.derive_dimension_order(order, **kwargs) + + assert order == expected + + def test_atmosphere_ln_pressure_coordinate(): lev = xr.DataArray( [0, 1, 2], @@ -325,7 +340,9 @@ def test_ocean_s_coordinate_g2(): def test_ocean_sigma_z_coordinate(): - zlev = xr.DataArray([0, 1, np.nan], dims=("lev",), name="zlev", attrs={"standard_name": "altitude"}) + zlev = xr.DataArray( + [0, 1, np.nan], dims=("lev",), name="zlev", attrs={"standard_name": "altitude"} + ) _sigma = xr.DataArray([np.nan, np.nan, 3], dims=("lev",), name="sigma") @@ -411,12 +428,34 @@ def test_ocean_double_sigma_coordinate(): assert output.attrs["standard_name"] == "altitude" -@pytest.mark.parametrize("input,expected", [ - ({"zlev": {"standard_name": "altitude"}}, "altitude"), - ({"zlev": {"standard_name": "altitude"}, "eta": {"standard_name": "sea_surface_height_above_geoid"}}, "altitude"), - ({"zlev": {"standard_name": "altitude"}, "eta": {"standard_name": "sea_surface_height_above_geoid"}, "depth": {"standard_name": "sea_floor_depth_below_geoid"}}, "altitude"), - ({"eta": {"standard_name": "sea_surface_height_above_geoid"}, "depth": {"standard_name": "sea_floor_depth_below_geoid"}}, "altitude"), -]) +@pytest.mark.parametrize( + "input,expected", + [ + ({"zlev": {"standard_name": "altitude"}}, "altitude"), + ( + { + "zlev": {"standard_name": "altitude"}, + "eta": {"standard_name": "sea_surface_height_above_geoid"}, + }, + "altitude", + ), + ( + { + "zlev": {"standard_name": "altitude"}, + "eta": {"standard_name": "sea_surface_height_above_geoid"}, + "depth": {"standard_name": "sea_floor_depth_below_geoid"}, + }, + "altitude", + ), + ( + { + "eta": {"standard_name": "sea_surface_height_above_geoid"}, + "depth": {"standard_name": "sea_floor_depth_below_geoid"}, + }, + "altitude", + ), + ], +) def test_derive_ocean_stdname(input, expected): output = parametric._derive_ocean_stdname(**input) @@ -424,7 +463,9 @@ def test_derive_ocean_stdname(input, expected): def test_derive_ocean_stdname_no_values(): - with pytest.raises(ValueError, match="Must provide atleast one of depth, eta, zlev."): + with pytest.raises( + ValueError, match="Must provide atleast one of depth, eta, zlev." + ): parametric._derive_ocean_stdname() @@ -434,12 +475,17 @@ def test_derive_ocean_stdname_empty_value(): def test_derive_ocean_stdname_no_standard_name(): - with pytest.raises(ValueError, match="The standard name for the 'zlev' variable is not available."): + with pytest.raises( + ValueError, match="The standard name for the 'zlev' variable is not available." + ): parametric._derive_ocean_stdname(zlev={}) def test_derive_ocean_stdname_no_match(): - with pytest.raises(ValueError, match="Could not derive standard name from combination of not in any list."): + with pytest.raises( + ValueError, + match="Could not derive standard name from combination of not in any list.", + ): parametric._derive_ocean_stdname(zlev={"standard_name": "not in any list"}) From 2bacb6f5fc0b2a5f89949597fe401df76095b3cb Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 17 Jul 2024 10:32:39 -0700 Subject: [PATCH 04/22] Removes hardcoded dims --- cf_xarray/parametric.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 27a4cf44..8e26ba7e 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -612,11 +612,15 @@ def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_over_z_coordinate """ - n, j, i = eta.shape + z_shape = list(eta.shape) - k = sigma.shape[0] + z_shape.insert(1, sigma.shape[0]) - z = xr.DataArray(np.empty((n, k, j, i)), dims=("time", "lev", "lat", "lon")) + z_dims = list(eta.dims) + + z_dims.insert(1, sigma.dims[0]) + + z = xr.DataArray(np.empty(z_shape), dims=z_dims) z_sigma = eta + sigma * (np.minimum(depth_c, depth) + eta) @@ -666,13 +670,14 @@ def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_double_sigma_coordinate """ - k = sigma.shape[0] + f = 0.5 * (z1 + z2) + 0.5 * (z1 - z2) * np.tanh(2 * a / (z1 - z2) * (depth - href)) - j, i = depth.shape + # shape k, j, i + z_shape = sigma.shape + depth.shape - f = 0.5 * (z1 + z2) + 0.5 * (z1 - z2) * np.tanh(2 * a / (z1 - z2) * (depth - href)) + z_dims = sigma.dims + depth.dims - z = xr.DataArray(np.empty((k, j, i)), dims=("lev", "lat", "lon"), name="z") + z = xr.DataArray(np.empty(z_shape), dims=z_dims, name="z") z = xr.where(sigma.k <= k_c, sigma * f, z) From e37173aa170eb027114e10cc6ae2ba844432ed9a Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Thu, 18 Jul 2024 10:31:32 -0700 Subject: [PATCH 05/22] Fixes optional argument bug and handling case insensitive terms --- cf_xarray/accessor.py | 3 ++- cf_xarray/parametric.py | 27 +++++++++++++-------------- cf_xarray/tests/test_accessor.py | 4 +++- cf_xarray/tests/test_parametric.py | 16 +++++++++++++--- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 83f2cc77..d7717699 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2705,7 +2705,8 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): f"Variable {value!r} is required to decode coordinate for {dim!r}" " but it is absent in the Dataset." ) - terms[key] = ds[value] + # keys should be case insensitive + terms[key.lower()] = ds[value] try: func = parametric.func_from_stdname(stdname) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 8e26ba7e..f97ce423 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -140,6 +140,9 @@ def check_requirements(func, terms): if len(terms) == 0: raise KeyError(f"Required terms {', '.join(sorted(args))} absent in dataset.") + # handle case insensitive + terms = {x.lower() for x in terms} + n = len(spec.defaults or []) # last `n` arguments are optional @@ -151,20 +154,16 @@ def check_requirements(func, terms): if (req & terms) != req: req_diff = sorted(req - terms) - opt_err = "" - - # optional is not present in terms - if len(opt) > 0 and not (opt <= terms): - opt_err = f"and atleast one optional term {', '.join(sorted(opt))} " - raise KeyError( - f"Required terms {', '.join(req_diff)} {opt_err}absent in dataset." + f"Required terms {', '.join(req_diff)} are absent in the dataset." ) - # atleast one optional is in diff, only required for atmoshphere hybrid sigma pressure coordinate - if len(opt) > 0 and not (opt <= terms): + # if there are optional arguments check that atleast one + # is in the intersection, only required for + # atmosphere_hybrid_sigma_pressure_coordinate + if len(opt) > 0 and len(opt & terms) == 0: raise KeyError( - f"Atleast one of the optional terms {', '.join(sorted(opt))} is absent in dataset." + f"Atleast one optional term {', '.join(sorted(opt))} is absent in the dataset." ) @@ -496,7 +495,7 @@ def ocean_s_coordinate(s, eta, depth, a, b, depth_c): return z.transpose(*output_order) -def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): +def ocean_s_coordinate_g1(s, c, eta, depth, depth_c): """Ocean s-coordinate, generic form 1. Standard name: ocean_s_coordinate_g1 @@ -524,7 +523,7 @@ def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_1 """ - S = depth_c * s + (depth - depth_c) * C + S = depth_c * s + (depth - depth_c) * c z = S + eta * (1 + s / depth) @@ -537,7 +536,7 @@ def ocean_s_coordinate_g1(s, C, eta, depth, depth_c): return z.transpose(*output_order) -def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): +def ocean_s_coordinate_g2(s, c, eta, depth, depth_c): """Ocean s-coordinate, generic form 2. Standard name: ocean_s_coordinate_g2 @@ -565,7 +564,7 @@ def ocean_s_coordinate_g2(s, C, eta, depth, depth_c): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_2 """ - S = (depth_c * s + depth * C) / (depth_c + depth) + S = (depth_c * s + depth * c) / (depth_c + depth) z = eta + (eta + depth) * S diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 3541804a..3298145d 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1318,7 +1318,9 @@ def test_decode_vertical_coords() -> None: romsds_less_h = romsds.drop_vars(["h"]) - with pytest.raises(KeyError, match="Required terms depth absent in dataset."): + with pytest.raises( + KeyError, match="Required terms depth are absent in the dataset." + ): romsds_less_h.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py index d7356520..392e4afd 100644 --- a/cf_xarray/tests/test_parametric.py +++ b/cf_xarray/tests/test_parametric.py @@ -508,7 +508,7 @@ def test_check_requirements(): with pytest.raises( KeyError, - match=r"'Required terms b, p0 and atleast one optional term a, ap absent in dataset.'", + match=r"'Required terms b, p0 are absent in the dataset.'", ): parametric.check_requirements( parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps"] @@ -516,7 +516,7 @@ def test_check_requirements(): with pytest.raises( KeyError, - match="'Atleast one of the optional terms a, ap is absent in dataset.'", + match="'Atleast one optional term a, ap is absent in the dataset.'", ): parametric.check_requirements( parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "b"] @@ -524,8 +524,18 @@ def test_check_requirements(): with pytest.raises( KeyError, - match="'Required terms b and atleast one optional term a, ap absent in dataset.'", + match="'Required terms b are absent in the dataset.'", ): parametric.check_requirements( parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "a"] ) + + # Should pass + parametric.check_requirements( + parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "b", "a"] + ) + + # check case insensitive + parametric.check_requirements( + parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "P0", "b", "A"] + ) From 24aa5dfcb09e95be240e0b51c4abc992aa0bc129 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Sat, 17 Aug 2024 00:36:03 -0700 Subject: [PATCH 06/22] Moves from function to class based implementation of transforms --- cf_xarray/accessor.py | 9 +- cf_xarray/parametric.py | 727 ++++++++++++++++++----------- cf_xarray/tests/test_parametric.py | 372 +++++++++------ 3 files changed, 680 insertions(+), 428 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index cf56ca03..da23c26a 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2789,17 +2789,14 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): terms[key.lower()] = ds[value] try: - func = parametric.func_from_stdname(stdname) - except AttributeError: + transform = parametric.TRANSFORM_FROM_STDNAME[stdname] + except KeyError: # Should occur since stdname is check before raise NotImplementedError( f"Coordinate function for {stdname!r} not implmented yet. Contributions welcome!" ) from None - # let KeyError propagate - parametric.check_requirements(func, terms) - - ds.coords[zname] = func(**terms) + ds.coords[zname] = transform.from_terms(terms) @xr.register_dataarray_accessor("cf") diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index f97ce423..95aacf0d 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -1,8 +1,9 @@ -import inspect -import sys +from abc import ABC, abstractmethod +from collections.abc import Sequence import numpy as np import xarray as xr +from xarray import DataArray ocean_stdname_map = { "altitude": { @@ -110,126 +111,23 @@ def _derive_ocean_stdname(**kwargs): return found_stdname -def check_requirements(func, terms): - """Checks terms against function requirements. +class ParamerticVerticalCoordinate(ABC): + @classmethod + @abstractmethod + def from_terms(cls, terms: dict): + pass - Uses `func` argument specification as requirements and checks terms against this. - Postitional arguments without a default are required but when a default value is - provided the arguement is considered optional. Atleast one optional argument must - be present (special case for atmosphere_hybrid_sigma_pressure_coordinate). + @abstractmethod + def decode(self): + pass - Parameters - ---------- - func : function - Function to check requirements. - terms : list - List of terms to check `func` requirements against. - - Raises - ------ - KeyError - If `terms` is empty or missing required/optional terms. - """ - if not isinstance(terms, set): - terms = set(terms) - - spec = inspect.getfullargspec(func) - - args = spec.args or [] - - if len(terms) == 0: - raise KeyError(f"Required terms {', '.join(sorted(args))} absent in dataset.") - - # handle case insensitive - terms = {x.lower() for x in terms} - - n = len(spec.defaults or []) - - # last `n` arguments are optional - opt = set(args[len(args) - n :]) - - req = set(args) - opt - - # req must all be present in terms - if (req & terms) != req: - req_diff = sorted(req - terms) - - raise KeyError( - f"Required terms {', '.join(req_diff)} are absent in the dataset." - ) - - # if there are optional arguments check that atleast one - # is in the intersection, only required for - # atmosphere_hybrid_sigma_pressure_coordinate - if len(opt) > 0 and len(opt & terms) == 0: - raise KeyError( - f"Atleast one optional term {', '.join(sorted(opt))} is absent in the dataset." - ) - - -def func_from_stdname(stdname): - """Get function from module. - - Uses `stdname` to return function from module. - - Parameters - ---------- - stdname : str - Name of the function. - - Raises - ------ - AttributeError - If a function name `stdname` is not in the module. - """ - m = sys.modules[__name__] - - return getattr(m, stdname) - - -def derive_dimension_order(output_order, **dim_map): - """Derive dimension ordering from input map. - - This will derive a dimensinal ordering from a map of dimension - identifiers and variables containing the dimensions. - - This is useful when dimension names are not know. - - For example if the desired output ordering was "nkji" where - variable "A" contains "nji" (time, lat, lon) and "B" contains - "k" (height) then the output would be (time, height, lat, lon). + @property + @abstractmethod + def computed_standard_name(self): + pass - This also works when dimensions are missing. - For example if the desired output ordering was "nkji" where - variable "A" contains "n" (time) and "B" contains - "k" (height) then the output would be (time, height). - - Parameters - ---------- - output_order : str - Dimension identifiers in desired order, e.g. "nkji". - **dim_map : dict - Dimension identifiers and variable containing them, e.g. "nji": eta, "k": s. - - Returns - ------- - list - Output dimensions in desired order. - """ - dims = {} - - for x, y in dim_map.items(): - for i, z in enumerate(x): - try: - dims[z] = y.dims[i] - except IndexError: - dims[z] = None - - return tuple(dims[x] for x in list(output_order) if dims[x] is not None) - - -def atmosphere_ln_pressure_coordinate(p0, lev): +class AtmosphereLnPressure(ParamerticVerticalCoordinate): """Atmosphere natural log pressure coordinate. Standard name: atmosphere_ln_pressure_coordinate @@ -241,24 +139,42 @@ def atmosphere_ln_pressure_coordinate(p0, lev): lev : xr.DataArray Vertical dimensionless coordinate. - Returns - ------- - xr.DataArray - A DataArray with new pressure coordinate. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-natural-log-pressure-coordinate """ - p = p0 * np.exp(-lev) - p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") + def __init__(self, p0, lev): + self.p0 = p0 + self.lev = lev + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + p = self.p0 * np.exp(-self.lev) + + return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) - return p + @property + def computed_standard_name(self): + """Computes coordinate standard name.""" + return "air_pressure" + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + p0, lev = get_terms(terms, "p0", "lev") -def atmosphere_sigma_coordinate(sigma, ps, ptop): + return cls(p0, lev) + + +class AtmosphereSigma(ParamerticVerticalCoordinate): """Atmosphere sigma coordinate. Standard name: atmosphere_sigma_coordinate @@ -270,26 +186,43 @@ def atmosphere_sigma_coordinate(sigma, ps, ptop): ps : xr.DataArray Horizontal surface pressure. - Returns - ------- - xr.DataArray - A DataArray with new pressure coordinate. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_sigma_coordinate """ - p = ptop + sigma * (ps - ptop) - p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") + def __init__(self, sigma, ps, ptop): + self.sigma = sigma + self.ps = ps + self.ptop = ptop + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + p = self.ptop + self.sigma * (self.ps - self.ptop) + + return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return "air_pressure" - output_order = derive_dimension_order("nkji", nji=ps, k=sigma) + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + sigma, ps, ptop = get_terms(terms, "sigma", "ps", "ptop") - return p.transpose(*output_order) + return cls(sigma, ps, ptop) -def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): +class AtmosphereHybridSigmaPressure(ParamerticVerticalCoordinate): """Atmosphere hybrid sigma pressure coordinate. Standard name: atmosphere_hybrid_sigma_pressure_coordinate @@ -307,29 +240,63 @@ def atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=None, ap=None): ap : xr.DataArray Component of hybrid coordinate. - Returns - ------- - xr.DataArray - A DataArray with new pressure coordinate. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_hybrid_sigma_pressure_coordinate """ - if a is None: - p = ap + b * ps - else: - p = a * p0 + b * ps - - p = p.squeeze().rename("p").assign_attrs(standard_name="air_pressure") - output_order = derive_dimension_order("nkji", nji=ps, k=b) - - return p.transpose(*output_order) - - -def atmosphere_hybrid_height_coordinate(a, b, orog): + def __init__(self, b, ps, p0=None, a=None, ap=None): + self.b = b + self.ps = ps + self.p0 = p0 + self.a = a + self.ap = ap + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + if self.a is None: + p = self.ap + self.b * self.ps + else: + p = self.a * self.p0 + self.b * self.ps + + return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return "air_pressure" + + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + b, ps, p0, a, ap = get_terms(terms, "b", "ps", optional=("p0", "a", "ap")) + + if a is None and ap is None: + raise KeyError( + "Optional terms 'a', 'ap' are absent in the dataset, atleast one must be present." + ) + + if a is not None and ap is not None: + raise Exception( + "Both optional terms 'a' and 'ap' are present in the dataset, please drop one of them." + ) + + if a is not None and p0 is None: + raise KeyError( + "Optional term 'a' is present but 'p0' is absent in the dataset." + ) + + return cls(b, ps, p0, a, ap) + + +class AtmosphereHybridHeight(ParamerticVerticalCoordinate): """Atmosphere hybrid height coordinate. Standard name: atmosphere_hybrid_height_coordinate @@ -343,33 +310,50 @@ def atmosphere_hybrid_height_coordinate(a, b, orog): orog : xr.DataArray Height of the surface above the datum. - Returns - ------- - xr.DataArray - A DataArray with the height above the datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-hybrid-height-coordinate """ - z = a + b * orog - orog_stdname = orog.attrs["standard_name"] + def __init__(self, a, b, orog): + self.a = a + self.b = b + self.orog = orog + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + z = self.a + self.b * self.orog + + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) - if orog_stdname == "surface_altitude": - out_stdname = "altitude" - elif orog_stdname == "surface_height_above_geopotential_datum": - out_stdname = "height_above_geopotential_datum" + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + orog_stdname = self.orog.attrs["standard_name"] - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + if orog_stdname == "surface_altitude": + out_stdname = "altitude" + elif orog_stdname == "surface_height_above_geopotential_datum": + out_stdname = "height_above_geopotential_datum" - output_order = derive_dimension_order("nkji", nji=orog, k=b) + return out_stdname - return z.transpose(*output_order) + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + a, b, orog = get_terms(terms, "a", "b", "orog") + return cls(a, b, orog) -def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): + +class AtmosphereSleve(ParamerticVerticalCoordinate): """Atmosphere smooth level vertical (SLEVE) coordinate. Standard name: atmosphere_sleve_coordinate @@ -389,33 +373,57 @@ def atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2): zsurf2 : xr.DataArray Small-scale component of the topography. - Returns - ------- - xr.DataArray - A DataArray with the height above the datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_smooth_level_vertical_sleve_coordinate """ - z = a * ztop + b1 * zsurf1 + b2 * zsurf2 - ztop_stdname = ztop.attrs["standard_name"] - - if ztop_stdname == "altitude_at_top_of_atmosphere_model": - out_stdname = "altitude" - elif ztop_stdname == "height_above_geopotential_datum_at_top_of_atmosphere_model": - out_stdname = "height_above_geopotential_datum" - - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - - output_order = derive_dimension_order("nkji", nji=zsurf1, k=a) + def __init__(self, a, b1, b2, ztop, zsurf1, zsurf2): + self.a = a + self.b1 = b1 + self.b2 = b2 + self.ztop = ztop + self.zsurf1 = zsurf1 + self.zsurf2 = zsurf2 + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + z = self.a * self.ztop + self.b1 * self.zsurf1 + self.b2 * self.zsurf2 + + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + ztop_stdname = self.ztop.attrs["standard_name"] + + if ztop_stdname == "altitude_at_top_of_atmosphere_model": + out_stdname = "altitude" + elif ( + ztop_stdname == "height_above_geopotential_datum_at_top_of_atmosphere_model" + ): + out_stdname = "height_above_geopotential_datum" + + return out_stdname + + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + a, b1, b2, ztop, zsurf1, zsurf2 = get_terms( + terms, "a", "b1", "b2", "ztop", "zsurf1", "zsurf2" + ) - return z.transpose(*output_order) + return cls(a, b1, b2, ztop, zsurf1, zsurf2) -def ocean_sigma_coordinate(sigma, eta, depth): +class OceanSigma(ParamerticVerticalCoordinate): """Ocean sigma coordinate. Standard name: ocean_sigma_coordinate @@ -429,28 +437,45 @@ def ocean_sigma_coordinate(sigma, eta, depth): depth : xr.DataArray Distance (positive value) from the datum to the sea floor. - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to the datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_coordinate """ - z = eta + sigma * (depth + eta) - out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + def __init__(self, sigma, eta, depth): + self.sigma = sigma + self.eta = eta + self.depth = depth + + def decode(self) -> xr.DataArray: + """Decode coordinate. - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + z = self.eta + self.sigma * (self.depth + self.eta) - output_order = derive_dimension_order("nkji", nji=eta, k=sigma) + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) - return z.transpose(*output_order) + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + out_stdname = _derive_ocean_stdname(eta=self.eta.attrs, depth=self.depth.attrs) + return out_stdname -def ocean_s_coordinate(s, eta, depth, a, b, depth_c): + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + sigma, eta, depth = get_terms(terms, "sigma", "eta", "depth") + + return cls(sigma, eta, depth) + + +class OceanS(ParamerticVerticalCoordinate): """Ocean s-coordinate. Standard name: ocean_s_coordinate @@ -470,32 +495,56 @@ def ocean_s_coordinate(s, eta, depth, a, b, depth_c): depth_c : xr.DataArray Constant controlling stretch. - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to the datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate """ - C = (1 - b) * np.sinh(a * s) / np.sinh(a) + b * ( - np.tanh(a * (s + 0.5)) / 2 * np.tanh(0.5 * a) - 0.5 - ) - z = eta * (1 + s) + depth_c * s + (depth - depth_c) * C + def __init__(self, s, eta, depth, a, b, depth_c): + self.s = s + self.eta = eta + self.depth = depth + self.a = a + self.b = b + self.depth_c = depth_c + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + C = (1 - self.b) * np.sinh(self.a * self.s) / np.sinh(self.a) + self.b * ( + np.tanh(self.a * (self.s + 0.5)) / 2 * np.tanh(0.5 * self.a) - 0.5 + ) - out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) + z = ( + self.eta * (1 + self.s) + + self.depth_c * self.s + + (self.depth - self.depth_c) * C + ) - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) - output_order = derive_dimension_order("nkji", nji=eta, k=s) + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return _derive_ocean_stdname(eta=self.eta.attrs, depth=self.depth.attrs) - return z.transpose(*output_order) + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + s, eta, depth, a, b, depth_c = get_terms( + terms, "s", "eta", "depth", "a", "b", "depth_c" + ) + return cls(s, eta, depth, a, b, depth_c) -def ocean_s_coordinate_g1(s, c, eta, depth, depth_c): + +class OceanSG1(ParamerticVerticalCoordinate): """Ocean s-coordinate, generic form 1. Standard name: ocean_s_coordinate_g1 @@ -513,30 +562,49 @@ def ocean_s_coordinate_g1(s, c, eta, depth, depth_c): depth_c : xr.DataArray Constant (positive value) is a critical depth controlling the stretching. - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to ocean datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_1 """ - S = depth_c * s + (depth - depth_c) * c - z = S + eta * (1 + s / depth) - - out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) - - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - - output_order = derive_dimension_order("nkji", nji=eta, k=s) + def __init__(self, s, c, eta, depth, depth_c): + self.s = s + self.c = c + self.eta = eta + self.depth = depth + self.depth_c = depth_c + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + S = self.depth_c * self.s + (self.depth - self.depth_c) * self.c + + z = S + self.eta * (1 + self.s / self.depth) + + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return _derive_ocean_stdname(eta=self.eta.attrs, depth=self.depth.attrs) + + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + s, c, eta, depth, depth_c = get_terms( + terms, "s", "c", "eta", "depth", "depth_c" + ) - return z.transpose(*output_order) + return cls(s, c, eta, depth, depth_c) -def ocean_s_coordinate_g2(s, c, eta, depth, depth_c): +class OceanSG2(ParamerticVerticalCoordinate): """Ocean s-coordinate, generic form 2. Standard name: ocean_s_coordinate_g2 @@ -554,30 +622,49 @@ def ocean_s_coordinate_g2(s, c, eta, depth, depth_c): depth_c : xr.DataArray Constant (positive value) is a critical depth controlling the stretching. - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to ocean datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_2 """ - S = (depth_c * s + depth * c) / (depth_c + depth) - - z = eta + (eta + depth) * S - - out_stdname = _derive_ocean_stdname(eta=eta.attrs, depth=depth.attrs) - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) - - output_order = derive_dimension_order("nkji", nji=eta, k=s) + def __init__(self, s, c, eta, depth, depth_c): + self.s = s + self.c = c + self.eta = eta + self.depth = depth + self.depth_c = depth_c + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + S = (self.depth_c * self.s + self.depth * self.c) / (self.depth_c + self.depth) + + z = self.eta + (self.eta + self.depth) * S + + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return _derive_ocean_stdname(eta=self.eta.attrs, depth=self.depth.attrs) + + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + s, c, eta, depth, depth_c = get_terms( + terms, "s", "c", "eta", "depth", "depth_c" + ) - return z.transpose(*output_order) + return cls(s, c, eta, depth, depth_c) -def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): +class OceanSigmaZ(ParamerticVerticalCoordinate): """Ocean sigma over z coordinate. Standard name: ocean_sigma_z_coordinate @@ -597,11 +684,6 @@ def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): zlev : xr.DataArray Coordinate defined only for `nlayer - nsigma` where `nlayer` is the size of the vertical coordinate. - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to the ocean datum. - Notes ----- The description of this type of parametric vertical coordinate is defective in version 1.8 and earlier versions of the standard, in that it does not state what values the vertical coordinate variable should contain. Therefore, in accordance with the rules, all versions of the standard before 1.9 are deprecated for datasets that use the "ocean sigma over z" coordinate. @@ -611,34 +693,61 @@ def ocean_sigma_z_coordinate(sigma, eta, depth, depth_c, nsigma, zlev): Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_over_z_coordinate """ - z_shape = list(eta.shape) - z_shape.insert(1, sigma.shape[0]) + def __init__(self, sigma, eta, depth, depth_c, nsigma, zlev): + self.sigma = sigma + self.eta = eta + self.depth = depth + self.depth_c = depth_c + self.nsigma = nsigma + self.zlev = zlev + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + z_shape = list(self.eta.shape) + + z_shape.insert(1, self.sigma.shape[0]) - z_dims = list(eta.dims) + z_dims = list(self.eta.dims) - z_dims.insert(1, sigma.dims[0]) + z_dims.insert(1, self.sigma.dims[0]) - z = xr.DataArray(np.empty(z_shape), dims=z_dims) + z = xr.DataArray(np.empty(z_shape), dims=z_dims) - z_sigma = eta + sigma * (np.minimum(depth_c, depth) + eta) + z_sigma = self.eta + self.sigma * ( + np.minimum(self.depth_c, self.depth) + self.eta + ) - z = xr.where(~np.isnan(sigma), z_sigma, z) + z = xr.where(~np.isnan(self.sigma), z_sigma, z) - z = xr.where(np.isnan(sigma), zlev, z) + z = xr.where(np.isnan(self.sigma), self.zlev, z) - out_stdname = _derive_ocean_stdname( - eta=eta.attrs, depth=depth.attrs, zlev=zlev.attrs - ) + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return _derive_ocean_stdname( + eta=self.eta.attrs, depth=self.depth.attrs, zlev=self.zlev.attrs + ) - output_order = derive_dimension_order("nkji", nji=eta, k=sigma) + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + sigma, eta, depth, depth_c, nsigma, zlev = get_terms( + terms, "sigma", "eta", "depth", "depth_c", "nsigma", "zlev" + ) - return z.transpose(*output_order) + return cls(sigma, eta, depth, depth_c, nsigma, zlev) -def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): +class OceanDoubleSigma(ParamerticVerticalCoordinate): """Ocean double sigma coordinate. Standard name: ocean_double_sigma_coordinate @@ -659,33 +768,95 @@ def ocean_double_sigma_coordinate(sigma, depth, z1, z2, a, href, k_c): Constant with units of length. k_c : xr.DataArray - Returns - ------- - xr.DataArray - A DataArray with the height (positive upwards) relative to the datum. - References ---------- Please refer to the CF conventions document : 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_double_sigma_coordinate """ - f = 0.5 * (z1 + z2) + 0.5 * (z1 - z2) * np.tanh(2 * a / (z1 - z2) * (depth - href)) - # shape k, j, i - z_shape = sigma.shape + depth.shape + def __init__(self, sigma, depth, z1, z2, a, href, k_c): + self.sigma = sigma + self.depth = depth + self.z1 = z1 + self.z2 = z2 + self.a = a + self.href = href + self.k_c = k_c + + def decode(self) -> xr.DataArray: + """Decode coordinate. + + Returns + ------- + xr.DataArray + Decoded parametric vertical coordinate. + """ + f = 0.5 * (self.z1 + self.z2) + 0.5 * (self.z1 - self.z2) * np.tanh( + 2 * self.a / (self.z1 - self.z2) * (self.depth - self.href) + ) + + # shape k, j, i + z_shape = self.sigma.shape + self.depth.shape + + z_dims = self.sigma.dims + self.depth.dims + + z = xr.DataArray(np.empty(z_shape), dims=z_dims, name="z") + + z = xr.where(self.sigma.k <= self.k_c, self.sigma * f, z) + + z = xr.where( + self.sigma.k > self.k_c, f + (self.sigma - 1) * (self.depth - f), z + ) + + return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + + @property + def computed_standard_name(self) -> str: + """Computes coordinate standard name.""" + return _derive_ocean_stdname(depth=self.depth.attrs) + + @classmethod + def from_terms(cls, terms: dict): + """Create coordinate from terms.""" + sigma, depth, z1, z2, a, href, k_c = get_terms( + terms, "sigma", "depth", "z1", "z2", "a", "href", "k_c" + ) - z_dims = sigma.dims + depth.dims + return cls(sigma, depth, z1, z2, a, href, k_c) + + +TRANSFORM_FROM_STDNAME = { + "atmosphere_ln_pressure_coordinate": AtmosphereLnPressure, + "atmosphere_sigma_coordinate": AtmosphereSigma, + "atmosphere_hybrid_sigma_pressure_coordinate": AtmosphereHybridSigmaPressure, + "atmosphere_hybrid_height_coordinate": AtmosphereHybridHeight, + "atmosphere_sleve_coordinate": AtmosphereSleve, + "ocean_sigma_coordinate": OceanSigma, + "ocean_s_coordinate": OceanS, + "ocean_s_coordinate_g1": OceanSG1, + "ocean_s_coordinate_g2": OceanSG2, + "ocean_sigma_z_coordinate": OceanSigmaZ, + "ocean_double_sigma_coordinate": OceanDoubleSigma, +} - z = xr.DataArray(np.empty(z_shape), dims=z_dims, name="z") - z = xr.where(sigma.k <= k_c, sigma * f, z) +def get_terms( + terms: dict[str, DataArray], *required, optional: Sequence[str] = None +) -> DataArray: + if optional is None: + optional = [] - z = xr.where(sigma.k > k_c, f + (sigma - 1) * (depth - f), z) + selected_terms = [] - out_stdname = _derive_ocean_stdname(depth=depth.attrs) + for term in required + tuple(optional): + da = None - z = z.squeeze().rename("z").assign_attrs(standard_name=out_stdname) + try: + da = terms[term] + except KeyError: + if term not in optional: + raise KeyError(f"Required term {term} is absent in dataset.") from None - output_order = derive_dimension_order("kji", ji=depth, k=sigma) + selected_terms.append(da) - return z.transpose(*output_order) + return selected_terms diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py index 392e4afd..03615ee0 100644 --- a/cf_xarray/tests/test_parametric.py +++ b/cf_xarray/tests/test_parametric.py @@ -41,21 +41,6 @@ s = xr.DataArray([0, 1, 2], dims=("lev"), name="s") -@pytest.mark.parametrize( - "order,kwargs,expected", - [ - ("nkji", {"nji": ps, "k": a}, ("time", "lev", "lat", "lon")), - ("nkij", {"nji": ps, "k": a}, ("time", "lev", "lon", "lat")), - ("ijn", {"nji": ps}, ("lon", "lat", "time")), - ("njk", {"nji": ps, "k": a}, ("time", "lat", "lev")), - ], -) -def test_derive_dimension_order(order, kwargs, expected): - order = parametric.derive_dimension_order(order, **kwargs) - - assert order == expected - - def test_atmosphere_ln_pressure_coordinate(): lev = xr.DataArray( [0, 1, 2], @@ -63,94 +48,128 @@ def test_atmosphere_ln_pressure_coordinate(): name="lev", ) - output = parametric.atmosphere_ln_pressure_coordinate(p0, lev) + transform = parametric.AtmosphereLnPressure.from_terms( + { + "p0": p0, + "lev": lev, + } + ) + + output = transform.decode() expected = xr.DataArray([10.0, 3.678794, 1.353353], dims=("lev",), name="p") assert_allclose(output, expected) - assert output.name == "p" assert output.attrs["standard_name"] == "air_pressure" def test_atmosphere_sigma_coordinate(): ptop = xr.DataArray([0.98692327], name="ptop") - output = parametric.atmosphere_sigma_coordinate(sigma, ps, ptop) + transform = parametric.AtmosphereSigma.from_terms( + { + "sigma": sigma, + "ps": ps, + "ptop": ptop, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[0.986923, 0.986923], [0.986923, 0.986923]], - [[1.0, 1.0], [1.0, 1.0]], - [[1.013077, 1.013077], [1.013077, 1.013077]], + [[0.986923, 0.986923], [0.986923, 0.986923]], ], [ - [[0.986923, 0.986923], [0.986923, 0.986923]], [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + ], + [ + [[1.013077, 1.013077], [1.013077, 1.013077]], [[1.013077, 1.013077], [1.013077, 1.013077]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), coords={"k": (("lev",), [0, 1, 2])}, name="p", ) assert_allclose(output, expected) - assert output.name == "p" assert output.attrs["standard_name"] == "air_pressure" def test_atmosphere_hybrid_sigma_pressure_coordinate(): ap = xr.DataArray([3, 4, 5], dims=("lev",), name="ap") - output = parametric.atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, a=a) + transform = parametric.AtmosphereHybridSigmaPressure.from_terms( + { + "b": b, + "ps": ps, + "a": a, + "p0": p0, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[6.0, 6.0], [6.0, 6.0]], - [[17.0, 17.0], [17.0, 17.0]], - [[28.0, 28.0], [28.0, 28.0]], + [[6.0, 6.0], [6.0, 6.0]], ], [ - [[6.0, 6.0], [6.0, 6.0]], [[17.0, 17.0], [17.0, 17.0]], + [[17.0, 17.0], [17.0, 17.0]], + ], + [ + [[28.0, 28.0], [28.0, 28.0]], [[28.0, 28.0], [28.0, 28.0]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), name="p", ) assert_allclose(output, expected) - assert output.name == "p" assert output.attrs["standard_name"] == "air_pressure" - output = parametric.atmosphere_hybrid_sigma_pressure_coordinate(b, ps, p0, ap=ap) + transform = parametric.AtmosphereHybridSigmaPressure.from_terms( + { + "b": b, + "ps": ps, + "ap": ap, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[9.0, 9.0], [9.0, 9.0]], - [[11.0, 11.0], [11.0, 11.0]], - [[13.0, 13.0], [13.0, 13.0]], + [[9.0, 9.0], [9.0, 9.0]], ], [ - [[9.0, 9.0], [9.0, 9.0]], [[11.0, 11.0], [11.0, 11.0]], + [[11.0, 11.0], [11.0, 11.0]], + ], + [ + [[13.0, 13.0], [13.0, 13.0]], [[13.0, 13.0], [13.0, 13.0]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), name="p", ) assert_allclose(output, expected) - assert output.name == "p" assert output.attrs["standard_name"] == "air_pressure" @@ -161,28 +180,37 @@ def test_atmosphere_hybrid_height_coordinate(): attrs={"standard_name": "surface_altitude"}, ) - output = parametric.atmosphere_hybrid_height_coordinate(a, b, orog) + transform = parametric.AtmosphereHybridHeight.from_terms( + { + "a": a, + "b": b, + "orog": orog, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[0.0, 0.0], [0.0, 0.0]], - [[1.0, 1.0], [1.0, 1.0]], - [[2.0, 2.0], [2.0, 2.0]], + [[0.0, 0.0], [0.0, 0.0]], ], [ - [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + ], + [ + [[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), name="p", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" @@ -201,55 +229,84 @@ def test_atmosphere_sleve_coordinate(): zsurf2 = xr.DataArray(np.ones((2, 2, 2)), dims=("time", "lat", "lon")) - output = parametric.atmosphere_sleve_coordinate(a, b1, b2, ztop, zsurf1, zsurf2) + transform = parametric.AtmosphereSleve.from_terms( + { + "a": a, + "b1": b1, + "b2": b2, + "ztop": ztop, + "zsurf1": zsurf1, + "zsurf2": zsurf2, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[1.0, 1.0], [1.0, 1.0]], - [[31.0, 31.0], [31.0, 31.0]], - [[61.0, 61.0], [61.0, 61.0]], + [[1.0, 1.0], [1.0, 1.0]], ], [ - [[1.0, 1.0], [1.0, 1.0]], [[31.0, 31.0], [31.0, 31.0]], + [[31.0, 31.0], [31.0, 31.0]], + ], + [ + [[61.0, 61.0], [61.0, 61.0]], [[61.0, 61.0], [61.0, 61.0]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), name="z", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" def test_ocean_sigma_coordinate(): - output = parametric.ocean_sigma_coordinate(sigma, eta, depth) + transform = parametric.OceanSigma.from_terms( + { + "sigma": sigma, + "eta": eta, + "depth": depth, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("time", "lat", "lon", "lev"), name="z", coords={"k": (("lev",), [0, 1, 2])}, ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" @@ -258,84 +315,142 @@ def test_ocean_s_coordinate(): _b = xr.DataArray([1], name="b") - output = parametric.ocean_s_coordinate(s, eta, depth, _a, _b, depth_c) + transform = parametric.OceanS.from_terms( + { + "s": s, + "eta": eta, + "depth": depth, + "a": _a, + "b": _b, + "depth_c": depth_c, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ - [[12.403492, 12.403492], [12.403492, 12.403492]], - [[40.434874, 40.434874], [40.434874, 40.434874]], - [[70.888995, 70.888995], [70.888995, 70.888995]], + [ + [12.403492, 40.434874, 70.888995], + [12.403492, 40.434874, 70.888995], + ], + [ + [12.403492, 40.434874, 70.888995], + [12.403492, 40.434874, 70.888995], + ], ], [ - [[12.403492, 12.403492], [12.403492, 12.403492]], - [[40.434874, 40.434874], [40.434874, 40.434874]], - [[70.888995, 70.888995], [70.888995, 70.888995]], + [ + [12.403492, 40.434874, 70.888995], + [12.403492, 40.434874, 70.888995], + ], + [ + [12.403492, 40.434874, 70.888995], + [12.403492, 40.434874, 70.888995], + ], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("time", "lat", "lon", "lev"), name="z", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" def test_ocean_s_coordinate_g1(): C = xr.DataArray([0, 1, 2], dims=("lev",), name="C") - output = parametric.ocean_s_coordinate_g1(s, C, eta, depth, depth_c) + transform = parametric.OceanSG2.from_terms( + { + "s": s, + "c": C, + "eta": eta, + "depth": depth, + "depth_c": depth_c, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("time", "lat", "lon", "lev"), name="z", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" def test_ocean_s_coordinate_g2(): C = xr.DataArray([0, 1, 2], dims=("lev",), name="C") - output = parametric.ocean_s_coordinate_g2(s, C, eta, depth, depth_c) + transform = parametric.OceanSG2.from_terms( + { + "s": s, + "c": C, + "eta": eta, + "depth": depth, + "depth_c": depth_c, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], [ - [[1.0, 1.0], [1.0, 1.0]], - [[3.0, 3.0], [3.0, 3.0]], - [[5.0, 5.0], [5.0, 5.0]], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], + [ + [1.0, 3.0, 5.0], + [1.0, 3.0, 5.0], + ], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("time", "lat", "lon", "lev"), name="z", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" @@ -346,28 +461,40 @@ def test_ocean_sigma_z_coordinate(): _sigma = xr.DataArray([np.nan, np.nan, 3], dims=("lev",), name="sigma") - output = parametric.ocean_sigma_z_coordinate(_sigma, eta, depth, depth_c, 10, zlev) + transform = parametric.OceanSigmaZ.from_terms( + { + "sigma": _sigma, + "eta": eta, + "depth": depth, + "depth_c": depth_c, + "nsigma": 10, + "zlev": zlev, + } + ) + + output = transform.decode() expected = xr.DataArray( [ [ [[0.0, 0.0], [0.0, 0.0]], - [[1.0, 1.0], [1.0, 1.0]], - [[7.0, 7.0], [7.0, 7.0]], + [[0.0, 0.0], [0.0, 0.0]], ], [ - [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + ], + [ + [[7.0, 7.0], [7.0, 7.0]], [[7.0, 7.0], [7.0, 7.0]], ], ], - dims=("time", "lev", "lat", "lon"), + dims=("lev", "time", "lat", "lon"), name="z", ) assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" @@ -407,10 +534,20 @@ def test_ocean_double_sigma_coordinate(): name="a", ) - output = parametric.ocean_double_sigma_coordinate( - sigma, depth, z1, z2, a, href, k_c + transform = parametric.OceanDoubleSigma.from_terms( + { + "sigma": sigma, + "depth": depth, + "z1": z1, + "z2": z2, + "a": a, + "href": href, + "k_c": k_c, + } ) + output = transform.decode() + expected = xr.DataArray( [ [[0.0, 0.0], [0.0, 0.0]], @@ -424,7 +561,6 @@ def test_ocean_double_sigma_coordinate(): assert_allclose(output, expected) - assert output.name == "z" assert output.attrs["standard_name"] == "altitude" @@ -487,55 +623,3 @@ def test_derive_ocean_stdname_no_match(): match="Could not derive standard name from combination of not in any list.", ): parametric._derive_ocean_stdname(zlev={"standard_name": "not in any list"}) - - -def test_func_from_stdname(): - with pytest.raises(AttributeError): - parametric.func_from_stdname("test") - - func = parametric.func_from_stdname("atmosphere_ln_pressure_coordinate") - - assert func == parametric.atmosphere_ln_pressure_coordinate - - -def test_check_requirements(): - with pytest.raises(KeyError, match="'Required terms lev, p0 absent in dataset.'"): - parametric.check_requirements(parametric.atmosphere_ln_pressure_coordinate, []) - - parametric.check_requirements( - parametric.atmosphere_ln_pressure_coordinate, ["p0", "lev"] - ) - - with pytest.raises( - KeyError, - match=r"'Required terms b, p0 are absent in the dataset.'", - ): - parametric.check_requirements( - parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps"] - ) - - with pytest.raises( - KeyError, - match="'Atleast one optional term a, ap is absent in the dataset.'", - ): - parametric.check_requirements( - parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "b"] - ) - - with pytest.raises( - KeyError, - match="'Required terms b are absent in the dataset.'", - ): - parametric.check_requirements( - parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "a"] - ) - - # Should pass - parametric.check_requirements( - parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "p0", "b", "a"] - ) - - # check case insensitive - parametric.check_requirements( - parametric.atmosphere_hybrid_sigma_pressure_coordinate, ["ps", "P0", "b", "A"] - ) From c76828b9f6ba179b361c0e1ab727b88c6ca68b12 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Tue, 20 Aug 2024 13:35:30 -0700 Subject: [PATCH 07/22] Fixes failing test --- cf_xarray/parametric.py | 2 +- cf_xarray/tests/test_accessor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 95aacf0d..429b593d 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -855,7 +855,7 @@ def get_terms( da = terms[term] except KeyError: if term not in optional: - raise KeyError(f"Required term {term} is absent in dataset.") from None + raise KeyError(f"Required term {term} is absent in the dataset.") from None selected_terms.append(da) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 12d84218..3bcb608b 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1319,7 +1319,7 @@ def test_decode_vertical_coords() -> None: romsds_less_h = romsds.drop_vars(["h"]) with pytest.raises( - KeyError, match="Required terms depth are absent in the dataset." + KeyError, match="Required term depth is absent in the dataset." ): romsds_less_h.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) From 2e02ef9894dc18adb5d346d24de8fd98a7621db4 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Tue, 20 Aug 2024 15:56:20 -0700 Subject: [PATCH 08/22] Removes extras spaces in text --- cf_xarray/parametric.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 429b593d..5766830e 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -13,18 +13,18 @@ }, "height_above_geopotential_datum": { "zlev": "height_above_geopotential_datum", - "eta": "sea_surface_height_above_ geopotential_datum", - "depth": "sea_floor_depth_below_ geopotential_datum", + "eta": "sea_surface_height_above_geopotential_datum", + "depth": "sea_floor_depth_below_geopotential_datum", }, "height_above_reference_ellipsoid": { "zlev": "height_above_reference_ellipsoid", - "eta": "sea_surface_height_above_ reference_ellipsoid", - "depth": "sea_floor_depth_below_ reference_ellipsoid", + "eta": "sea_surface_height_above_reference_ellipsoid", + "depth": "sea_floor_depth_below_reference_ellipsoid", }, "height_above_mean_sea_level": { "zlev": "height_above_mean_sea_level", - "eta": "sea_surface_height_above_mean_ sea_level", - "depth": "sea_floor_depth_below_mean_ sea_level", + "eta": "sea_surface_height_above_mean_sea_level", + "depth": "sea_floor_depth_below_mean_sea_level", }, } From 08170d1a5312a02c1fcb9a77e965fdd8977f3b95 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Tue, 20 Aug 2024 15:56:43 -0700 Subject: [PATCH 09/22] Fixes formatting --- cf_xarray/parametric.py | 4 +++- cf_xarray/tests/test_accessor.py | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 5766830e..456977ff 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -855,7 +855,9 @@ def get_terms( da = terms[term] except KeyError: if term not in optional: - raise KeyError(f"Required term {term} is absent in the dataset.") from None + raise KeyError( + f"Required term {term} is absent in the dataset." + ) from None selected_terms.append(da) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 3bcb608b..d40289ad 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -1318,9 +1318,7 @@ def test_decode_vertical_coords() -> None: romsds_less_h = romsds.drop_vars(["h"]) - with pytest.raises( - KeyError, match="Required term depth is absent in the dataset." - ): + with pytest.raises(KeyError, match="Required term depth is absent in the dataset."): romsds_less_h.cf.decode_vertical_coords(outnames={"s_rho": "z_rho"}) From 80d12147c17c091fd4ca8e4e4e3ceaba3a9cd85b Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Tue, 20 Aug 2024 16:36:42 -0700 Subject: [PATCH 10/22] Resolves mypy errors --- cf_xarray/parametric.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 456977ff..c9240dc7 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -278,17 +278,17 @@ def from_terms(cls, terms: dict): """Create coordinate from terms.""" b, ps, p0, a, ap = get_terms(terms, "b", "ps", optional=("p0", "a", "ap")) - if a is None and ap is None: + if a is None and ap is None: # type: ignore[unreachable] raise KeyError( "Optional terms 'a', 'ap' are absent in the dataset, atleast one must be present." ) - if a is not None and ap is not None: + if a is not None and ap is not None: # type: ignore[redundant-expr] raise Exception( "Both optional terms 'a' and 'ap' are present in the dataset, please drop one of them." ) - if a is not None and p0 is None: + if a is not None and p0 is None: # type: ignore[unreachable] raise KeyError( "Optional term 'a' is present but 'p0' is absent in the dataset." ) @@ -841,8 +841,8 @@ def from_terms(cls, terms: dict): def get_terms( - terms: dict[str, DataArray], *required, optional: Sequence[str] = None -) -> DataArray: + terms: dict[str, DataArray], *required, optional: Sequence[str] | None = None +) -> list[DataArray]: if optional is None: optional = [] @@ -861,4 +861,4 @@ def get_terms( selected_terms.append(da) - return selected_terms + return selected_terms # type: ignore[return-value] From 381ea30bb59a118817062cce6ca87db5829343fd Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 00:12:20 -0700 Subject: [PATCH 11/22] Removes redundant code --- cf_xarray/parametric.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index c9240dc7..bdf0f18a 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -710,23 +710,11 @@ def decode(self) -> xr.DataArray: xr.DataArray Decoded parametric vertical coordinate. """ - z_shape = list(self.eta.shape) - - z_shape.insert(1, self.sigma.shape[0]) - - z_dims = list(self.eta.dims) - - z_dims.insert(1, self.sigma.dims[0]) - - z = xr.DataArray(np.empty(z_shape), dims=z_dims) - z_sigma = self.eta + self.sigma * ( np.minimum(self.depth_c, self.depth) + self.eta ) - z = xr.where(~np.isnan(self.sigma), z_sigma, z) - - z = xr.where(np.isnan(self.sigma), self.zlev, z) + z = xr.where(np.isnan(self.sigma), self.zlev, z_sigma) return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) From 54f64929e1a63f7d87a4c5160517d0c820c68622 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 00:13:21 -0700 Subject: [PATCH 12/22] Fixes global variable case --- cf_xarray/parametric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index bdf0f18a..26f01ab8 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -5,7 +5,7 @@ import xarray as xr from xarray import DataArray -ocean_stdname_map = { +OCEAN_STDNAME_MAP = { "altitude": { "zlev": "altitude", "eta": "sea_surface_height_above_geoid", From 32eae7e9f9b649b017af3a1b711b4662c159f842 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 00:50:30 -0700 Subject: [PATCH 13/22] Updates _derive_ocean_stdname --- cf_xarray/parametric.py | 28 ++++++++++++++-------------- cf_xarray/tests/test_parametric.py | 12 ------------ 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 26f01ab8..60992cac 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -29,7 +29,7 @@ } -def _derive_ocean_stdname(**kwargs): +def _derive_ocean_stdname(*, zlev=None, eta=None, depth=None): """Derive standard name for computer ocean coordinates. Uses the concatentation of formula terms `zlev`, `eta`, and `depth` @@ -63,33 +63,29 @@ def _derive_ocean_stdname(**kwargs): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#table-computed-standard-names """ found_stdname = None - - allowed_names = {"zlev", "eta", "depth"} - - if len(kwargs) == 0 or not (set(kwargs) <= allowed_names): - raise ValueError( - f"Must provide atleast one of {', '.join(sorted(allowed_names))}." - ) - search_term = "" + search_vars = {"zlev": zlev, "eta": eta, "depth": depth} + + for x, y in sorted(search_vars.items(), key=lambda x: x[0]): + if y is None: + continue - for x, y in sorted(kwargs.items(), key=lambda x: x[0]): try: search_term = f"{search_term}{y['standard_name']}" except TypeError: raise ValueError( - f"The values for {', '.join(sorted(kwargs.keys()))} cannot be `None`." + f"The values for {', '.join(sorted(search_vars.keys()))} cannot be `None`." ) from None except KeyError: raise ValueError( f"The standard name for the {x!r} variable is not available." ) from None - for x, y in ocean_stdname_map.items(): + for x, y in OCEAN_STDNAME_MAP.items(): check_term = "".join( [ y[i] - for i, j in sorted(kwargs.items(), key=lambda x: x[0]) + for i, j in sorted(search_vars.items(), key=lambda x: x[0]) if j is not None ] ) @@ -101,7 +97,11 @@ def _derive_ocean_stdname(**kwargs): if found_stdname is None: stdnames = ", ".join( - [y["standard_name"] for _, y in sorted(kwargs.items(), key=lambda x: x[0])] + [ + y["standard_name"] + for _, y in sorted(search_vars.items(), key=lambda x: x[0]) + if y is not None + ] ) raise ValueError( diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py index 03615ee0..c64e7e3d 100644 --- a/cf_xarray/tests/test_parametric.py +++ b/cf_xarray/tests/test_parametric.py @@ -598,18 +598,6 @@ def test_derive_ocean_stdname(input, expected): assert output == expected -def test_derive_ocean_stdname_no_values(): - with pytest.raises( - ValueError, match="Must provide atleast one of depth, eta, zlev." - ): - parametric._derive_ocean_stdname() - - -def test_derive_ocean_stdname_empty_value(): - with pytest.raises(ValueError, match="The values for zlev cannot be `None`."): - parametric._derive_ocean_stdname(zlev=None) - - def test_derive_ocean_stdname_no_standard_name(): with pytest.raises( ValueError, match="The standard name for the 'zlev' variable is not available." From 9d22976a46f99213da4b0c2fdf6ce5eeabc8b62d Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 01:08:43 -0700 Subject: [PATCH 14/22] Moves to dataclass --- cf_xarray/parametric.py | 158 ++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 78 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 60992cac..7b1a184c 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from dataclasses import dataclass import numpy as np import xarray as xr @@ -127,6 +128,7 @@ def computed_standard_name(self): pass +@dataclass class AtmosphereLnPressure(ParamerticVerticalCoordinate): """Atmosphere natural log pressure coordinate. @@ -145,9 +147,8 @@ class AtmosphereLnPressure(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-natural-log-pressure-coordinate """ - def __init__(self, p0, lev): - self.p0 = p0 - self.lev = lev + p0: DataArray + lev: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -174,6 +175,7 @@ def from_terms(cls, terms: dict): return cls(p0, lev) +@dataclass class AtmosphereSigma(ParamerticVerticalCoordinate): """Atmosphere sigma coordinate. @@ -192,10 +194,9 @@ class AtmosphereSigma(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_sigma_coordinate """ - def __init__(self, sigma, ps, ptop): - self.sigma = sigma - self.ps = ps - self.ptop = ptop + sigma: DataArray + ps: DataArray + ptop: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -222,6 +223,7 @@ def from_terms(cls, terms: dict): return cls(sigma, ps, ptop) +@dataclass class AtmosphereHybridSigmaPressure(ParamerticVerticalCoordinate): """Atmosphere hybrid sigma pressure coordinate. @@ -246,12 +248,27 @@ class AtmosphereHybridSigmaPressure(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_hybrid_sigma_pressure_coordinate """ - def __init__(self, b, ps, p0=None, a=None, ap=None): - self.b = b - self.ps = ps - self.p0 = p0 - self.a = a - self.ap = ap + b: DataArray + ps: DataArray + p0: DataArray + a: DataArray + ap: DataArray + + def __post_init__(self): + if self.a is None and self.ap is None: + raise KeyError( + "Optional terms 'a', 'ap' are absent in the dataset, atleast one must be present." + ) + + if self.a is not None and self.ap is not None: + raise Exception( + "Both optional terms 'a' and 'ap' are present in the dataset, please drop one of them." + ) + + if self.a is not None and self.p0 is None: + raise KeyError( + "Optional term 'a' is present but 'p0' is absent in the dataset." + ) def decode(self) -> xr.DataArray: """Decode coordinate. @@ -262,7 +279,7 @@ def decode(self) -> xr.DataArray: Decoded parametric vertical coordinate. """ if self.a is None: - p = self.ap + self.b * self.ps + p = self.ap + self.b * self.ps # type: ignore[unreachable] else: p = self.a * self.p0 + self.b * self.ps @@ -278,24 +295,10 @@ def from_terms(cls, terms: dict): """Create coordinate from terms.""" b, ps, p0, a, ap = get_terms(terms, "b", "ps", optional=("p0", "a", "ap")) - if a is None and ap is None: # type: ignore[unreachable] - raise KeyError( - "Optional terms 'a', 'ap' are absent in the dataset, atleast one must be present." - ) - - if a is not None and ap is not None: # type: ignore[redundant-expr] - raise Exception( - "Both optional terms 'a' and 'ap' are present in the dataset, please drop one of them." - ) - - if a is not None and p0 is None: # type: ignore[unreachable] - raise KeyError( - "Optional term 'a' is present but 'p0' is absent in the dataset." - ) - return cls(b, ps, p0, a, ap) +@dataclass class AtmosphereHybridHeight(ParamerticVerticalCoordinate): """Atmosphere hybrid height coordinate. @@ -316,10 +319,9 @@ class AtmosphereHybridHeight(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#atmosphere-hybrid-height-coordinate """ - def __init__(self, a, b, orog): - self.a = a - self.b = b - self.orog = orog + a: DataArray + b: DataArray + orog: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -353,6 +355,7 @@ def from_terms(cls, terms: dict): return cls(a, b, orog) +@dataclass class AtmosphereSleve(ParamerticVerticalCoordinate): """Atmosphere smooth level vertical (SLEVE) coordinate. @@ -379,13 +382,12 @@ class AtmosphereSleve(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_atmosphere_smooth_level_vertical_sleve_coordinate """ - def __init__(self, a, b1, b2, ztop, zsurf1, zsurf2): - self.a = a - self.b1 = b1 - self.b2 = b2 - self.ztop = ztop - self.zsurf1 = zsurf1 - self.zsurf2 = zsurf2 + a: DataArray + b1: DataArray + b2: DataArray + ztop: DataArray + zsurf1: DataArray + zsurf2: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -423,6 +425,7 @@ def from_terms(cls, terms: dict): return cls(a, b1, b2, ztop, zsurf1, zsurf2) +@dataclass class OceanSigma(ParamerticVerticalCoordinate): """Ocean sigma coordinate. @@ -443,10 +446,9 @@ class OceanSigma(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_coordinate """ - def __init__(self, sigma, eta, depth): - self.sigma = sigma - self.eta = eta - self.depth = depth + sigma: DataArray + eta: DataArray + depth: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -475,6 +477,7 @@ def from_terms(cls, terms: dict): return cls(sigma, eta, depth) +@dataclass class OceanS(ParamerticVerticalCoordinate): """Ocean s-coordinate. @@ -501,13 +504,12 @@ class OceanS(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate """ - def __init__(self, s, eta, depth, a, b, depth_c): - self.s = s - self.eta = eta - self.depth = depth - self.a = a - self.b = b - self.depth_c = depth_c + s: DataArray + eta: DataArray + depth: DataArray + a: DataArray + b: DataArray + depth_c: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -544,6 +546,7 @@ def from_terms(cls, terms: dict): return cls(s, eta, depth, a, b, depth_c) +@dataclass class OceanSG1(ParamerticVerticalCoordinate): """Ocean s-coordinate, generic form 1. @@ -568,12 +571,11 @@ class OceanSG1(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_1 """ - def __init__(self, s, c, eta, depth, depth_c): - self.s = s - self.c = c - self.eta = eta - self.depth = depth - self.depth_c = depth_c + s: DataArray + c: DataArray + eta: DataArray + depth: DataArray + depth_c: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -604,6 +606,7 @@ def from_terms(cls, terms: dict): return cls(s, c, eta, depth, depth_c) +@dataclass class OceanSG2(ParamerticVerticalCoordinate): """Ocean s-coordinate, generic form 2. @@ -628,12 +631,11 @@ class OceanSG2(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_s_coordinate_generic_form_2 """ - def __init__(self, s, c, eta, depth, depth_c): - self.s = s - self.c = c - self.eta = eta - self.depth = depth - self.depth_c = depth_c + s: DataArray + c: DataArray + eta: DataArray + depth: DataArray + depth_c: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -664,6 +666,7 @@ def from_terms(cls, terms: dict): return cls(s, c, eta, depth, depth_c) +@dataclass class OceanSigmaZ(ParamerticVerticalCoordinate): """Ocean sigma over z coordinate. @@ -694,13 +697,12 @@ class OceanSigmaZ(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_sigma_over_z_coordinate """ - def __init__(self, sigma, eta, depth, depth_c, nsigma, zlev): - self.sigma = sigma - self.eta = eta - self.depth = depth - self.depth_c = depth_c - self.nsigma = nsigma - self.zlev = zlev + sigma: DataArray + eta: DataArray + depth: DataArray + depth_c: DataArray + nsigma: DataArray + zlev: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. @@ -735,6 +737,7 @@ def from_terms(cls, terms: dict): return cls(sigma, eta, depth, depth_c, nsigma, zlev) +@dataclass class OceanDoubleSigma(ParamerticVerticalCoordinate): """Ocean double sigma coordinate. @@ -762,14 +765,13 @@ class OceanDoubleSigma(ParamerticVerticalCoordinate): 1. https://cfconventions.org/cf-conventions/cf-conventions.html#_ocean_double_sigma_coordinate """ - def __init__(self, sigma, depth, z1, z2, a, href, k_c): - self.sigma = sigma - self.depth = depth - self.z1 = z1 - self.z2 = z2 - self.a = a - self.href = href - self.k_c = k_c + sigma: DataArray + depth: DataArray + z1: DataArray + z2: DataArray + a: DataArray + href: DataArray + k_c: DataArray def decode(self) -> xr.DataArray: """Decode coordinate. From 9396ef8e8bf8403cb8a60c6fb668e77e57ab6bc5 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 01:20:30 -0700 Subject: [PATCH 15/22] Fixes passing variables to class constructor --- cf_xarray/parametric.py | 62 ++++++++++------------------------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 7b1a184c..1a196e65 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -170,9 +170,7 @@ def computed_standard_name(self): @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - p0, lev = get_terms(terms, "p0", "lev") - - return cls(p0, lev) + return cls(**get_terms(terms, "p0", "lev")) @dataclass @@ -218,9 +216,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - sigma, ps, ptop = get_terms(terms, "sigma", "ps", "ptop") - - return cls(sigma, ps, ptop) + return cls(**get_terms(terms, "sigma", "ps", "ptop")) @dataclass @@ -293,9 +289,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - b, ps, p0, a, ap = get_terms(terms, "b", "ps", optional=("p0", "a", "ap")) - - return cls(b, ps, p0, a, ap) + return cls(**get_terms(terms, "b", "ps", optional=("p0", "a", "ap"))) @dataclass @@ -350,9 +344,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - a, b, orog = get_terms(terms, "a", "b", "orog") - - return cls(a, b, orog) + return cls(**get_terms(terms, "a", "b", "orog")) @dataclass @@ -418,11 +410,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - a, b1, b2, ztop, zsurf1, zsurf2 = get_terms( - terms, "a", "b1", "b2", "ztop", "zsurf1", "zsurf2" - ) - - return cls(a, b1, b2, ztop, zsurf1, zsurf2) + return cls(**get_terms(terms, "a", "b1", "b2", "ztop", "zsurf1", "zsurf2")) @dataclass @@ -472,9 +460,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - sigma, eta, depth = get_terms(terms, "sigma", "eta", "depth") - - return cls(sigma, eta, depth) + return cls(**get_terms(terms, "sigma", "eta", "depth")) @dataclass @@ -539,11 +525,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - s, eta, depth, a, b, depth_c = get_terms( - terms, "s", "eta", "depth", "a", "b", "depth_c" - ) - - return cls(s, eta, depth, a, b, depth_c) + return cls(**get_terms(terms, "s", "eta", "depth", "a", "b", "depth_c")) @dataclass @@ -599,11 +581,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - s, c, eta, depth, depth_c = get_terms( - terms, "s", "c", "eta", "depth", "depth_c" - ) - - return cls(s, c, eta, depth, depth_c) + return cls(**get_terms(terms, "s", "c", "eta", "depth", "depth_c")) @dataclass @@ -659,11 +637,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - s, c, eta, depth, depth_c = get_terms( - terms, "s", "c", "eta", "depth", "depth_c" - ) - - return cls(s, c, eta, depth, depth_c) + return cls(**get_terms(terms, "s", "c", "eta", "depth", "depth_c")) @dataclass @@ -730,12 +704,10 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - sigma, eta, depth, depth_c, nsigma, zlev = get_terms( - terms, "sigma", "eta", "depth", "depth_c", "nsigma", "zlev" + return cls( + **get_terms(terms, "sigma", "eta", "depth", "depth_c", "nsigma", "zlev") ) - return cls(sigma, eta, depth, depth_c, nsigma, zlev) - @dataclass class OceanDoubleSigma(ParamerticVerticalCoordinate): @@ -808,11 +780,7 @@ def computed_standard_name(self) -> str: @classmethod def from_terms(cls, terms: dict): """Create coordinate from terms.""" - sigma, depth, z1, z2, a, href, k_c = get_terms( - terms, "sigma", "depth", "z1", "z2", "a", "href", "k_c" - ) - - return cls(sigma, depth, z1, z2, a, href, k_c) + return cls(**get_terms(terms, "sigma", "depth", "z1", "z2", "a", "href", "k_c")) TRANSFORM_FROM_STDNAME = { @@ -832,11 +800,11 @@ def from_terms(cls, terms: dict): def get_terms( terms: dict[str, DataArray], *required, optional: Sequence[str] | None = None -) -> list[DataArray]: +) -> dict[str, DataArray]: if optional is None: optional = [] - selected_terms = [] + selected_terms = {} for term in required + tuple(optional): da = None @@ -849,6 +817,6 @@ def get_terms( f"Required term {term} is absent in the dataset." ) from None - selected_terms.append(da) + selected_terms[term] = da return selected_terms # type: ignore[return-value] From 9c16996ebb664f210d9ed0918aa8cee81c68fea9 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 01:24:48 -0700 Subject: [PATCH 16/22] Removes redundant code --- cf_xarray/parametric.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index 1a196e65..b529f9f3 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -757,17 +757,10 @@ def decode(self) -> xr.DataArray: 2 * self.a / (self.z1 - self.z2) * (self.depth - self.href) ) - # shape k, j, i - z_shape = self.sigma.shape + self.depth.shape - - z_dims = self.sigma.dims + self.depth.dims - - z = xr.DataArray(np.empty(z_shape), dims=z_dims, name="z") - - z = xr.where(self.sigma.k <= self.k_c, self.sigma * f, z) - z = xr.where( - self.sigma.k > self.k_c, f + (self.sigma - 1) * (self.depth - f), z + self.sigma.k <= self.k_c, + self.sigma * f, + f + (self.sigma - 1) * (self.depth - f), ) return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) From cfd919e0fe0b380363fdc5a2ad14e13993f8a29d Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Wed, 21 Aug 2024 10:38:53 -0700 Subject: [PATCH 17/22] Fixes handling unknown standard names Co-authored-by: Deepak Cherian --- cf_xarray/parametric.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index b529f9f3..c012b94f 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -338,6 +338,8 @@ def computed_standard_name(self) -> str: out_stdname = "altitude" elif orog_stdname == "surface_height_above_geopotential_datum": out_stdname = "height_above_geopotential_datum" + else: + raise ValueError(f"Unknown standard name for hybrid height coordinate: {orog_stdname!r}") return out_stdname @@ -404,6 +406,8 @@ def computed_standard_name(self) -> str: ztop_stdname == "height_above_geopotential_datum_at_top_of_atmosphere_model" ): out_stdname = "height_above_geopotential_datum" + else: + raise ValueError(f"Unknown standard name: {out_stdname!r}") return out_stdname From 1c5d49f9ccff7f7e34008eeab6348c17f17e90a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 17:39:10 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cf_xarray/parametric.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index c012b94f..da11b4d1 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -339,7 +339,9 @@ def computed_standard_name(self) -> str: elif orog_stdname == "surface_height_above_geopotential_datum": out_stdname = "height_above_geopotential_datum" else: - raise ValueError(f"Unknown standard name for hybrid height coordinate: {orog_stdname!r}") + raise ValueError( + f"Unknown standard name for hybrid height coordinate: {orog_stdname!r}" + ) return out_stdname From a966919d3fe23756a35f34a3d8edf066685acdca Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Thu, 10 Oct 2024 18:55:42 -0700 Subject: [PATCH 19/22] Removes squeeze and fixes constants --- cf_xarray/parametric.py | 22 +++++++------- cf_xarray/tests/test_parametric.py | 48 ++++++++++++++++-------------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index da11b4d1..b3184efd 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -160,7 +160,7 @@ def decode(self) -> xr.DataArray: """ p = self.p0 * np.exp(-self.lev) - return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return p.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self): @@ -206,7 +206,7 @@ def decode(self) -> xr.DataArray: """ p = self.ptop + self.sigma * (self.ps - self.ptop) - return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return p.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -279,7 +279,7 @@ def decode(self) -> xr.DataArray: else: p = self.a * self.p0 + self.b * self.ps - return p.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return p.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -327,7 +327,7 @@ def decode(self) -> xr.DataArray: """ z = self.a + self.b * self.orog - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -395,7 +395,7 @@ def decode(self) -> xr.DataArray: """ z = self.a * self.ztop + self.b1 * self.zsurf1 + self.b2 * self.zsurf2 - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -454,7 +454,7 @@ def decode(self) -> xr.DataArray: """ z = self.eta + self.sigma * (self.depth + self.eta) - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -521,7 +521,7 @@ def decode(self) -> xr.DataArray: + (self.depth - self.depth_c) * C ) - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -577,7 +577,7 @@ def decode(self) -> xr.DataArray: z = S + self.eta * (1 + self.s / self.depth) - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -633,7 +633,7 @@ def decode(self) -> xr.DataArray: z = self.eta + (self.eta + self.depth) * S - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -698,7 +698,7 @@ def decode(self) -> xr.DataArray: z = xr.where(np.isnan(self.sigma), self.zlev, z_sigma) - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: @@ -769,7 +769,7 @@ def decode(self) -> xr.DataArray: f + (self.sigma - 1) * (self.depth - f), ) - return z.squeeze().assign_attrs(standard_name=self.computed_standard_name) + return z.assign_attrs(standard_name=self.computed_standard_name) @property def computed_standard_name(self) -> str: diff --git a/cf_xarray/tests/test_parametric.py b/cf_xarray/tests/test_parametric.py index c64e7e3d..9708b3be 100644 --- a/cf_xarray/tests/test_parametric.py +++ b/cf_xarray/tests/test_parametric.py @@ -8,9 +8,9 @@ ps = xr.DataArray(np.ones((2, 2, 2)), dims=("time", "lat", "lon"), name="ps") p0 = xr.DataArray( - [ - 10, - ], + 10.0, + dims=(), + coords={}, name="p0", ) @@ -36,7 +36,7 @@ attrs={"standard_name": "sea_floor_depth_below_geoid"}, ) -depth_c = xr.DataArray([30.0], name="depth_c") +depth_c = xr.DataArray(30.0, dims=(), coords={}, name="depth_c") s = xr.DataArray([0, 1, 2], dims=("lev"), name="s") @@ -65,7 +65,7 @@ def test_atmosphere_ln_pressure_coordinate(): def test_atmosphere_sigma_coordinate(): - ptop = xr.DataArray([0.98692327], name="ptop") + ptop = xr.DataArray(0.98692327, dims=(), coords={}, name="ptop") transform = parametric.AtmosphereSigma.from_terms( { @@ -220,7 +220,9 @@ def test_atmosphere_sleve_coordinate(): b2 = xr.DataArray([1, 1, 0], dims=("lev",), name="b2") ztop = xr.DataArray( - [30.0], + 30.0, + dims=(), + coords={}, name="ztop", attrs={"standard_name": "altitude_at_top_of_atmosphere_model"}, ) @@ -311,9 +313,9 @@ def test_ocean_sigma_coordinate(): def test_ocean_s_coordinate(): - _a = xr.DataArray([1], name="a") + _a = xr.DataArray(1, dims=(), coords={}, name="a") - _b = xr.DataArray([1], name="b") + _b = xr.DataArray(1, dims=(), coords={}, name="b") transform = parametric.OceanS.from_terms( { @@ -500,37 +502,37 @@ def test_ocean_sigma_z_coordinate(): def test_ocean_double_sigma_coordinate(): k_c = xr.DataArray( - [ - 1, - ], + 1, + dims=(), + coords={}, name="k_c", ) href = xr.DataArray( - [ - 20.0, - ], + 20.0, + dims=(), + coords={}, name="href", ) z1 = xr.DataArray( - [ - 10.0, - ], + 10.0, + dims=(), + coords={}, name="z1", ) z2 = xr.DataArray( - [ - 30.0, - ], + 30.0, + dims=(), + coords={}, name="z2", ) a = xr.DataArray( - [ - 2.0, - ], + 2.0, + dims=(), + coords={}, name="a", ) From ca33c67bb11a717301b38172aabe456e4998a9d1 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Thu, 10 Oct 2024 19:05:25 -0700 Subject: [PATCH 20/22] Adds entry to CITATIONS.cff --- CITATION.cff | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CITATION.cff b/CITATION.cff index 1ca4a1c2..f4a53fa4 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -83,6 +83,10 @@ authors: - family-names: Haëck given-names: Clément affiliation: Laboratoire d'Océanographie et du Climat (LOCEAN), Paris + - family-names: Boutte + given-names: Jason + orcid: 'https://orcid.org/0009-0009-3996-3772' + affiliation: Lawrence Livermore National Laboratory identifiers: - type: doi value: 10.5281/zenodo.4749735 From 978948fd8dc18d3e80cc7ba1d7c9cc1814bf7352 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 22 Oct 2024 15:56:32 -0600 Subject: [PATCH 21/22] cleanup --- cf_xarray/parametric.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cf_xarray/parametric.py b/cf_xarray/parametric.py index b3184efd..06b3d7fd 100644 --- a/cf_xarray/parametric.py +++ b/cf_xarray/parametric.py @@ -112,7 +112,7 @@ def _derive_ocean_stdname(*, zlev=None, eta=None, depth=None): return found_stdname -class ParamerticVerticalCoordinate(ABC): +class ParametricVerticalCoordinate(ABC): @classmethod @abstractmethod def from_terms(cls, terms: dict): @@ -129,7 +129,7 @@ def computed_standard_name(self): @dataclass -class AtmosphereLnPressure(ParamerticVerticalCoordinate): +class AtmosphereLnPressure(ParametricVerticalCoordinate): """Atmosphere natural log pressure coordinate. Standard name: atmosphere_ln_pressure_coordinate @@ -174,7 +174,7 @@ def from_terms(cls, terms: dict): @dataclass -class AtmosphereSigma(ParamerticVerticalCoordinate): +class AtmosphereSigma(ParametricVerticalCoordinate): """Atmosphere sigma coordinate. Standard name: atmosphere_sigma_coordinate @@ -220,7 +220,7 @@ def from_terms(cls, terms: dict): @dataclass -class AtmosphereHybridSigmaPressure(ParamerticVerticalCoordinate): +class AtmosphereHybridSigmaPressure(ParametricVerticalCoordinate): """Atmosphere hybrid sigma pressure coordinate. Standard name: atmosphere_hybrid_sigma_pressure_coordinate @@ -257,7 +257,7 @@ def __post_init__(self): ) if self.a is not None and self.ap is not None: - raise Exception( + raise ValueError( "Both optional terms 'a' and 'ap' are present in the dataset, please drop one of them." ) @@ -293,7 +293,7 @@ def from_terms(cls, terms: dict): @dataclass -class AtmosphereHybridHeight(ParamerticVerticalCoordinate): +class AtmosphereHybridHeight(ParametricVerticalCoordinate): """Atmosphere hybrid height coordinate. Standard name: atmosphere_hybrid_height_coordinate @@ -352,7 +352,7 @@ def from_terms(cls, terms: dict): @dataclass -class AtmosphereSleve(ParamerticVerticalCoordinate): +class AtmosphereSleve(ParametricVerticalCoordinate): """Atmosphere smooth level vertical (SLEVE) coordinate. Standard name: atmosphere_sleve_coordinate @@ -420,7 +420,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanSigma(ParamerticVerticalCoordinate): +class OceanSigma(ParametricVerticalCoordinate): """Ocean sigma coordinate. Standard name: ocean_sigma_coordinate @@ -470,7 +470,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanS(ParamerticVerticalCoordinate): +class OceanS(ParametricVerticalCoordinate): """Ocean s-coordinate. Standard name: ocean_s_coordinate @@ -535,7 +535,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanSG1(ParamerticVerticalCoordinate): +class OceanSG1(ParametricVerticalCoordinate): """Ocean s-coordinate, generic form 1. Standard name: ocean_s_coordinate_g1 @@ -591,7 +591,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanSG2(ParamerticVerticalCoordinate): +class OceanSG2(ParametricVerticalCoordinate): """Ocean s-coordinate, generic form 2. Standard name: ocean_s_coordinate_g2 @@ -647,7 +647,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanSigmaZ(ParamerticVerticalCoordinate): +class OceanSigmaZ(ParametricVerticalCoordinate): """Ocean sigma over z coordinate. Standard name: ocean_sigma_z_coordinate @@ -716,7 +716,7 @@ def from_terms(cls, terms: dict): @dataclass -class OceanDoubleSigma(ParamerticVerticalCoordinate): +class OceanDoubleSigma(ParametricVerticalCoordinate): """Ocean double sigma coordinate. Standard name: ocean_double_sigma_coordinate From c9849f6dcc483f0292292484be29de5658fe0f96 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 22 Oct 2024 15:59:26 -0600 Subject: [PATCH 22/22] tweak docs --- doc/parametricz.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/parametricz.md b/doc/parametricz.md index 77f67924..897ae838 100644 --- a/doc/parametricz.md +++ b/doc/parametricz.md @@ -28,7 +28,7 @@ xr.set_options(display_expand_data=False) 3. {py:attr}`Dataset.cf.formula_terms` ``` -`cf_xarray` supports decoding [parametric vertical coordinates](http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#parametric-vertical-coordinate) encoded in the `formula_terms` attribute using {py:meth}`Dataset.cf.decode_vertical_coords`. Right now, only the two ocean s-coordinates and `ocean_sigma_coordinate` are supported, but support for the [rest](http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#parametric-v-coord) should be easy to add (Pull Requests are very welcome!). +`cf_xarray` supports decoding [parametric vertical coordinates](http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#parametric-vertical-coordinate) encoded in the `formula_terms` attribute using {py:meth}`Dataset.cf.decode_vertical_coords`. ## Decoding parametric coordinates