Skip to content

Commit 4f32df4

Browse files
authored
Merge pull request #28 from comet-toolkit/small_fixes
Small fixes
2 parents f148200 + 7c90cc0 commit 4f32df4

File tree

4 files changed

+100
-56
lines changed

4 files changed

+100
-56
lines changed

obsarray/err_corr.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,23 @@ def form(self) -> str:
6969
"""Form name"""
7070
pass
7171

72-
def expand_dim_matrix(
73-
self, submatrix: np.ndarray, submatrix_dim: Union[str, List[str]], sli: tuple
74-
):
75-
return expand_errcorr_dims(
76-
in_corr=submatrix,
77-
in_dim=submatrix_dim,
78-
out_dim=list(self._obj[self._unc_var_name][sli].dims),
79-
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli),
80-
)
72+
def get_varshape_errcorr(self):
73+
"""
74+
return shape of uncertainty variable, including only dimensions which are included in the current error correlation form.
75+
76+
:return: shape of included dimensions
77+
"""
78+
all_dims = self._obj[self._unc_var_name].dims
79+
all_dims_sizes = self._obj.sizes
80+
81+
return tuple([all_dims_sizes[dim] for dim in all_dims if dim in self.dims])
8182

8283
def get_sliced_dim_sizes_uncvar(self, sli: tuple) -> dict:
8384
"""
8485
Return dictionary with sizes of sliced dimensions of unc variable, including all dimensions.
8586
8687
:param sli: slice (tuple with slice for each dimension)
87-
:return: shape of included sliced dimensions
88+
:return: dictionary with shape of included sliced dimensions
8889
"""
8990
uncvar_dims = self._obj[self._unc_var_name][sli].dims
9091
uncvar_shape = self._obj[self._unc_var_name][sli].shape
@@ -98,7 +99,7 @@ def get_sliced_dim_sizes_errcorr(self, sli: tuple) -> dict:
9899
included in the current error correlation form.
99100
100101
:param sli: slice (tuple with slice for each dimension)
101-
:return: shape of included sliced dimensions
102+
:return: dictionary with shape of included sliced dimensions
102103
"""
103104
uncvar_sizes = self.get_sliced_dim_sizes_uncvar(sli)
104105
sliced_dims = self.get_sliced_dims_errcorr(sli)
@@ -131,17 +132,21 @@ def get_sliced_shape_errcorr(self, sli: tuple) -> tuple:
131132

132133
return tuple([uncvar_sizes[dim] for dim in sliced_dims])
133134

134-
def slice_full_cov(self, full_matrix: np.ndarray, sli: tuple) -> np.ndarray:
135-
return self.slice_flattened_matrix(
136-
full_matrix, self._obj[self._unc_var_name].shape, sli
137-
)
135+
def slice_errcorr_matrix(self, err_corr_matrix, variable_shape, sli) -> np.ndarray:
136+
"""
137+
Slice the provided error correlation matrix (typically the error correlation matrix of the
138+
BaseErrCorrForm) using the
138139
139-
def slice_flattened_matrix(self, flattened_matrix, variable_shape, sli):
140+
:param err_corr_matrix: error correlation matrix to be sliced
141+
:param variable_shape: tuple with the length of the dimensions in the error correlation matrix (in correct order for flattening)
142+
:param sli: slice of observation variable to return error-correlation matrix for
143+
:return: sliced error correlation matrix
144+
"""
140145
mask_array = np.ones(variable_shape, dtype=bool)
141146
mask_array[sli] = False
142147

143148
return np.delete(
144-
np.delete(flattened_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
149+
np.delete(err_corr_matrix, mask_array.ravel(), 0), mask_array.ravel(), 1
145150
)
146151

147152
@abc.abstractmethod
@@ -158,16 +163,24 @@ def build_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
158163

159164
def build_dot_matrix(self, sli: Union[np.ndarray, tuple]) -> np.ndarray:
160165
"""
161-
Returns uncertainty effect error-correlation matrix, populated with error-correlation values defined
162-
in this parameterisation
166+
Returns expanded error correlation matrix for use in dot product with error correlation
167+
in other dimensions.
163168
164-
:param sli: slice of observation variable to return error-correlation matrix for
169+
The (sliced) error correlation matrix for this BaseErrCorrForm is expanded from its current
170+
(sliced) dimensions (which often don't include all dimensions of the associated uncertainty
171+
variable) to the dimensions of the full (sliced) error correlation (i.e. all dimensions of
172+
the uncertainty).
173+
The returned matrix is not meaningfull unless combined in a dot product with the expanded
174+
matrices of other error correlation matrices (together spanning all uncertainty dimensions).
165175
166-
:return: populated error-correlation matrix
176+
:param sli: slice of observation variable to return error-correlation matrix for
177+
:return: expanded matrix for use in dot product with error correlation in other dimensions.
167178
"""
168-
169-
return self.expand_dim_matrix(
170-
self.build_matrix(sli), self.get_sliced_dims_errcorr(sli), sli
179+
return expand_errcorr_dims(
180+
in_corr=self.build_matrix(sli),
181+
in_dim=self.get_sliced_dims_errcorr(sli),
182+
out_dim=list(self._obj[self._unc_var_name][sli].dims),
183+
dim_sizes=self.get_sliced_dim_sizes_uncvar(sli),
171184
)
172185

173186

@@ -250,24 +263,14 @@ def build_matrix(self, sli: tuple) -> np.ndarray:
250263
251264
:return: populated error-correlation matrix
252265
"""
253-
254266
all_dims = self._obj[self._unc_var_name].dims
255-
all_dims_sizes = self._obj.sizes
256267

257268
sli_submatrix = tuple(
258269
[sli[i] for i in range(len(all_dims)) if all_dims[i] in self.dims]
259270
)
260271

261-
sliced_shape = tuple(
262-
[
263-
all_dims_sizes[all_dims[i]]
264-
for i in range(len(all_dims))
265-
if all_dims[i] in self.dims
266-
]
267-
)
268-
269-
submatrix = self.slice_flattened_matrix(
270-
self._obj[self.params[0]], sliced_shape, sli_submatrix
272+
submatrix = self.slice_errcorr_matrix(
273+
self._obj[self.params[0]], self.get_varshape_errcorr(), sli_submatrix
271274
)
272275

273276
return submatrix

obsarray/test/test_err_corr_forms.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
ErrCorrForms,
88
RandomCorrelation,
99
SystematicCorrelation,
10+
ErrCorrMatrixCorrelation,
1011
)
1112
from obsarray.test.test_unc_accessor import create_ds
1213

@@ -65,18 +66,26 @@ class BasicErrCorrForm(BaseErrCorrForm):
6566
form = "basic"
6667

6768
def build_matrix(self, sli):
68-
return None
69+
full_matrix = np.arange(144).reshape((12, 12))
70+
return self.slice_errcorr_matrix(full_matrix, (2, 2, 3), sli)
6971

7072
self.BasicErrCorrForm = BasicErrCorrForm
7173

74+
def test_get_varshape_errcorr(self):
75+
basicerrcorr = self.BasicErrCorrForm(
76+
self.ds, "u_ran_temperature", ["x", "y", "time"], [], []
77+
)
78+
shape = basicerrcorr.get_varshape_errcorr()
79+
np.testing.assert_equal(shape, (2, 2, 3))
80+
7281
def test_slice_full_cov_full(self):
7382
basicerrcorr = self.BasicErrCorrForm(
7483
self.ds, "u_ran_temperature", ["x"], [], []
7584
)
7685

7786
full_matrix = np.arange(144).reshape((12, 12))
78-
slice_matrix = basicerrcorr.slice_full_cov(
79-
full_matrix, (slice(None), slice(None), slice(None))
87+
slice_matrix = basicerrcorr.build_matrix(
88+
(slice(None), slice(None), slice(None))
8089
)
8190

8291
np.testing.assert_equal(full_matrix, slice_matrix)
@@ -118,13 +127,13 @@ def test_get_sliced_shape_errcorr(self):
118127
shape = basicerrcorr.get_sliced_shape_errcorr((slice(None), 0, slice(0, 2, 1)))
119128
assert shape == (2, 2)
120129

121-
def test_slice_flattened_matrix(self):
130+
def test_slice_errcorr_matrix(self):
122131
basicerrcorr = self.BasicErrCorrForm(
123-
self.ds, "u_ran_temperature", ["x"], [], []
132+
self.ds, "u_ran_temperature", ["x", "y", "z"], [], []
124133
)
125134

126135
full_matrix = np.arange(144).reshape((12, 12))
127-
slice_matrix = basicerrcorr.slice_flattened_matrix(
136+
slice_matrix = basicerrcorr.slice_errcorr_matrix(
128137
full_matrix, (2, 2, 3), (slice(None), slice(None), 0)
129138
)
130139

@@ -136,13 +145,10 @@ def test_slice_flattened_matrix(self):
136145

137146
def test_slice_full_cov_slice(self):
138147
basicerrcorr = self.BasicErrCorrForm(
139-
self.ds, "u_ran_temperature", ["x"], [], []
148+
self.ds, "u_ran_temperature", ["x", "y", "z"], [], []
140149
)
141150

142-
full_matrix = np.arange(144).reshape((12, 12))
143-
slice_matrix = basicerrcorr.slice_full_cov(
144-
full_matrix, (slice(None), slice(None), 0)
145-
)
151+
slice_matrix = basicerrcorr.build_dot_matrix((slice(None), slice(None), 0))
146152

147153
exp_slice_matrix = np.array(
148154
[[0, 3, 6, 9], [36, 39, 42, 45], [72, 75, 78, 81], [108, 111, 114, 117]]
@@ -206,5 +212,40 @@ def test_build_dot_matrix(self):
206212
np.testing.assert_equal((x.dot(y)).dot(time), np.ones((12, 12)))
207213

208214

215+
class TestErrCorrMatrixCorrelation(unittest.TestCase):
216+
def setUp(self) -> None:
217+
self.ds = create_ds()
218+
219+
def test_build_matrix_full(self):
220+
ec = ErrCorrMatrixCorrelation(
221+
self.ds,
222+
"u_str_temperature",
223+
["x", "time"],
224+
["err_corr_str_temperature"],
225+
[],
226+
)
227+
228+
ecrm = ec.build_matrix((slice(None), slice(None), slice(None)))
229+
np.testing.assert_equal(ecrm, np.ones((6, 6)))
230+
231+
def test_build_matrix_sliced(self):
232+
ec = ErrCorrMatrixCorrelation(
233+
self.ds,
234+
"u_str_temperature",
235+
["x", "time"],
236+
["err_corr_str_temperature"],
237+
[],
238+
)
239+
240+
ecrm = ec.build_matrix((0, slice(None), slice(None)))
241+
np.testing.assert_equal(ecrm, np.ones((3, 3)))
242+
243+
ecrm = ec.build_matrix((slice(None), 0, slice(None)))
244+
np.testing.assert_equal(ecrm, np.ones((6, 6)))
245+
246+
ecrm = ec.build_matrix((slice(None), slice(None), 0))
247+
np.testing.assert_equal(ecrm, np.ones((2, 2)))
248+
249+
209250
if __name__ == "main":
210251
unittest.main()

obsarray/test/test_unc_accessor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def create_ds():
9898
"units": "K",
9999
"err_corr": [
100100
{
101-
"dim": "x",
101+
"dim": ["x", "time"],
102102
"form": "err_corr_matrix",
103103
"params": ["err_corr_str_temperature"],
104104
},
@@ -107,19 +107,19 @@ def create_ds():
107107
"form": "systematic",
108108
"params": [],
109109
},
110-
{
111-
"dim": "time",
112-
"form": "systematic",
113-
"params": [],
114-
},
115110
],
116111
"pdf_shape": "gaussian",
117112
},
118113
)
119114

120115
ds["err_corr_str_temperature"] = (
121-
["x", "x"],
122-
np.eye(temperature.shape[0]),
116+
["x.time", "x.time"],
117+
np.ones(
118+
(
119+
temperature.shape[0] * temperature.shape[2],
120+
temperature.shape[0] * temperature.shape[2],
121+
)
122+
),
123123
)
124124

125125
return ds

obsarray/unc_accessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def err_cov_matrix(self):
354354
err_cov_matrix = empty_err_corr_matrix(self._obj[self._unc_var_name][self._sli])
355355

356356
err_cov_matrix.values = convert_corr_to_cov(
357-
self.err_corr_matrix().values, self.value.values
357+
self.err_corr_matrix().values, self.abs_value.values
358358
)
359359

360360
return err_cov_matrix

0 commit comments

Comments
 (0)