Skip to content

Commit b31fa09

Browse files
committed
initial commit of obs_concat defining interface
1 parent e69e81f commit b31fa09

File tree

2 files changed

+255
-0
lines changed

2 files changed

+255
-0
lines changed

obsarray/concat.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""obsarray.concat - module with extension to xarray.concat for obs_vars and unc_vars"""
2+
3+
import numpy as np
4+
import xarray as xr
5+
from typing import Union, Any
6+
from xarray.core.types import T_Dataset, T_DataArray, T_Variable
7+
from collections.abc import Hashable, Iterable
8+
9+
10+
__author__ = "Sam Hunt <sam.hunt@npl.co.uk>"
# Names exposed by ``from obsarray.concat import *``; an empty list would hide
# the module's only public function.
__all__ = ["obs_concat"]
12+
13+
14+
def obs_concat(
    objs: Iterable[T_DataArray],
    dim: Union[Hashable, T_Variable, T_DataArray, Any],
    unc: T_Dataset,
    dim_err_corr: Iterable[Any],
    combine_unc: str = "concat",
    *args,
    **kwargs
):
    """
    Concatenate xarray *obs_vars* along a new or existing dimension, safely handling also
    concatenating associated *unc_vars*. Extension to :py:func:`xarray.concat`.

    .. note::
       Only the interface is defined so far: *obs_vars* are concatenated with
       :py:func:`xarray.concat`, while ``unc`` and ``dim_err_corr`` are not yet used
       and the returned *unc_vars* are ``None``.

    :param objs: sequence of :py:class:`xarray.DataArray` *obs_vars* to concatenate
        together. As for :py:func:`xarray.concat`, each object is expected to consist
        of variables and coordinates with matching shapes except for along the
        concatenated dimension.
    :param dim: Name of the dimension to concatenate along. This can either be a new
        dimension name, in which case it is added along axis=0, or an existing
        dimension name, in which case the location of the dimension is
        unchanged. If dimension is provided as a Variable, DataArray or Index, its name
        is used as the dimension to concatenate along and the values are added
        as a coordinate.
    :param unc: dataset containing the unc_vars associated with objs
    :param dim_err_corr: error-correlation form definition for concatenation dimension
    :param combine_unc: string indicating how to concatenate unc_vars.

        * "concat": (default) merges *unc_vars* as for *obs_vars* - assumes *unc_var* order is the same between *obs_vars*
        * "no_concat": expands each *unc_var* along dim, gap filling with zeros
    :param args: further positional arguments passed on to :py:func:`xarray.concat`
    :param kwargs: further keyword arguments passed on to :py:func:`xarray.concat`

    :return: tuple of the concatenated *obs_vars* and the concatenated *unc_vars*
        (the latter currently always ``None``)
    :raises ValueError: if ``combine_unc`` is not one of the documented options
    """

    # Reject unknown modes up front rather than silently ignoring them - this also
    # catches a positional xarray.concat argument landing in ``combine_unc`` by mistake.
    valid_combine_unc = ("concat", "no_concat")
    if combine_unc not in valid_combine_unc:
        raise ValueError(
            f"combine_unc must be one of {valid_combine_unc}, got {combine_unc!r}"
        )

    concat_obs_vars = xr.concat(objs, dim, *args, **kwargs)

    # Placeholder until unc_var concatenation is implemented
    concat_unc_vars = None
    return concat_obs_vars, concat_unc_vars
48+
49+
50+
51+
52+
if __name__ == "__main__":
53+
pass

obsarray/test/test_concat.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
"""test_concat - tests for obsarray.concat"""
2+
3+
import unittest
4+
import numpy as np
5+
import obsarray
6+
import xarray as xr
7+
8+
9+
__author__ = "Sam Hunt <sam.hunt@npl.co.uk>"
10+
__all__ = []
11+
12+
from obsarray.concat import obs_concat
13+
14+
15+
def create_test_ds():
    """
    Build the input dataset used by the concatenation tests.

    The dataset holds five constant-valued variables on (``xa``, ``ya``) - ``d1a``,
    ``d2a``, ``s1a``, ``s2a`` and ``s3a`` - and one variable, ``d1b``, on
    (``xb``, ``yb``, ``zb``), plus all-ones coordinates for both geometries.
    ``d1a`` and ``d2a`` each receive two uncertainty components through the
    dataset's ``unc`` accessor: ``u_r_<name>`` with no attributes and ``u_s_<name>``
    with a "systematic" error-correlation definition over ``xa`` and ``ya``.

    :return: the populated :py:class:`xarray.Dataset`
    """
    shape_a, shape_b = (4, 3), (7, 5, 3)

    # name -> (fill value, "measurand" attribute, "m" attribute) for geometry "a"
    geometry_a_vars = {
        "d1a": (1, "d", 10),
        "d2a": (2, "d", 11),
        "s1a": (3, "s", 12),
        "s2a": (4, "s", 4),
        "s3a": (5, "s", 5),
    }

    data_vars = {}
    for name, (fill, measurand, m) in geometry_a_vars.items():
        attrs = {
            "units": "test_units",
            "geometry": "a",
            "measurand": measurand,
            "m": m,
        }
        data_vars[name] = (["xa", "ya"], np.ones(shape_a) * fill, attrs)

    # The single geometry "b" variable carries no "m" attribute
    data_vars["d1b"] = (
        ["xb", "yb", "zb"],
        np.ones(shape_b) * 1,
        {"units": "test_units", "geometry": "b", "measurand": "d"},
    )

    coords = {name: (["xa", "ya"], np.ones(shape_a)) for name in ("c1a", "c2a")}
    coords.update(
        {name: (["xb", "yb", "zb"], np.ones(shape_b)) for name in ("c1b", "c2b")}
    )

    ds = xr.Dataset(
        data_vars,
        coords=coords,
        attrs={
            "history": "test_history",
            "meas_vars": ["d1a", "d2a", "s1a", "s2a", "s3a", "d1b"],
        },
    )

    for name in ("d1a", "d2a"):
        # Uncertainty values simply reuse the variable's own values
        ds.unc[name][f"u_r_{name}"] = (["xa", "ya"], ds[name].values, {})

        systematic = {
            "dim": ["xa", "ya"],
            "form": "systematic",
            "params": [],
            "units": [],
        }
        ds.unc[name][f"u_s_{name}"] = (
            ["xa", "ya"],
            ds[name].values,
            {"err_corr": [systematic]},
        )

    return ds
73+
74+
def create_concat_ds():
    """
    Build the dataset expected after concatenating ``d1a`` and ``d2a`` from
    :py:func:`create_test_ds` into a single variable ``da`` along ``new_dim``.

    ``da`` has dims (``new_dim``, ``xa``, ``ya``): a new concatenation dimension is
    added along axis 0, so layer 0 holds the ``d1a`` values (1) and layer 1 the
    ``d2a`` values (2). The remaining variables and coordinates are unchanged.

    :return: the populated :py:class:`xarray.Dataset`
    """
    c1a = np.ones((4, 3))
    c2a = np.ones((4, 3))

    c1b = np.ones((7, 5, 3))
    c2b = np.ones((7, 5, 3))

    # Two stacked (xa, ya) layers: index 0 <- d1a (all 1), index 1 <- d2a (all 2)
    da = np.ones((2, 4, 3)) * 1
    da[1] = 2
    da_dims = ["new_dim", "xa", "ya"]

    s1a = np.ones((4, 3)) * 3
    s2a = np.ones((4, 3)) * 4
    s3a = np.ones((4, 3)) * 5

    d1b = np.ones((7, 5, 3)) * 1

    da_attrs = {"units": "test_units", "geometry": "a", "measurand": "d", "m": 10}
    s1a_attrs = {"units": "test_units", "geometry": "a", "measurand": "s", "m": 12}
    s2a_attrs = {"units": "test_units", "geometry": "a", "measurand": "s", "m": 4}
    s3a_attrs = {"units": "test_units", "geometry": "a", "measurand": "s", "m": 5}
    d1b_attrs = {"units": "test_units", "geometry": "b", "measurand": "d"}

    ds = xr.Dataset(
        {
            "da": (da_dims, da, da_attrs),
            "s1a": (["xa", "ya"], s1a, s1a_attrs),
            "s2a": (["xa", "ya"], s2a, s2a_attrs),
            "s3a": (["xa", "ya"], s3a, s3a_attrs),
            "d1b": (["xb", "yb", "zb"], d1b, d1b_attrs),
        },
        coords={
            "c1a": (["xa", "ya"], c1a),
            "c2a": (["xa", "ya"], c2a),
            "c1b": (["xb", "yb", "zb"], c1b),
            "c2b": (["xb", "yb", "zb"], c2b),
        },
        attrs={
            "history": "test_history",
            # d1a/d2a no longer exist separately - they are merged into "da"
            "meas_vars": ["da", "s1a", "s2a", "s3a", "d1b"],
        },
    )

    # NOTE(review): the names of the combined unc_vars and the error-correlation
    # form along "new_dim" are assumptions following the create_test_ds pattern -
    # confirm against the intended obs_concat behaviour.
    ds.unc["da"]["u_r_da"] = (da_dims, ds["da"].values, {})

    err_corr_def = [
        {
            "dim": ["xa", "ya"],
            "form": "systematic",
            "params": [],
            "units": [],
        }
    ]

    ds.unc["da"]["u_s_da"] = (da_dims, ds["da"].values, {"err_corr": err_corr_def})

    return ds
131+
132+
133+
134+
135+
class TestConcat(unittest.TestCase):
    """Tests for :py:func:`obsarray.concat.obs_concat`."""

    def test_concat_combine_unc_concat(self):
        """Concatenating ``d1a`` and ``d2a`` along a new dimension stacks their values."""
        ds = create_test_ds()

        # NOTE(review): the exact form expected for ``dim_err_corr`` is not pinned
        # down yet - this mirrors the ``err_corr`` entries used in create_test_ds.
        dim_err_corr = [
            {
                "dim": ["new_dim"],
                "form": "random",
                "params": [],
                "units": [],
            }
        ]

        # ``combine_unc`` is passed by keyword so it cannot be mistaken for
        # ``dim_err_corr``, which precedes it in the signature.
        obs_vars, unc_vars = obs_concat(
            [ds["d1a"], ds["d2a"]],
            "new_dim",
            ds,
            dim_err_corr,
            combine_unc="concat",
        )

        # A new concatenation dimension is added along axis 0
        expected_values = np.stack([ds["d1a"].values, ds["d2a"].values])

        self.assertEqual(obs_vars.dims, ("new_dim", "xa", "ya"))
        np.testing.assert_array_equal(obs_vars.values, expected_values)

        # TODO: assert on unc_vars once obs_concat builds the concatenated unc_vars
198+
199+
200+
201+
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script
    unittest.main()

0 commit comments

Comments
 (0)