Skip to content

Commit 3aeea78

Browse files
committed
Fixes and proper array broadcasting
1 parent 7a37063 commit 3aeea78

File tree

4 files changed

+42
-24
lines changed

4 files changed

+42
-24
lines changed

src/earthkit/meteo/stats/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
planned to work with objects like *earthkit.data FieldLists* or *xarray DataSets*.
1616
"""
1717

18+
from .extreme_values import * # noqa
1819
from .numpy_extended import * # noqa
1920
from .quantiles import * # noqa

src/earthkit/meteo/stats/array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
Statistical functions operating on numpy arrays.
1212
"""
1313

14+
from .extreme_values import * # noqa
1415
from .numpy_extended import * # noqa
1516
from .quantiles import * # noqa

src/earthkit/meteo/stats/array/distributions.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@
1313

1414

1515
class ContinuousDistribution(abc.ABC):
16-
"""Continuous probabiliy distribution function
16+
"""Continuous probabiliy distribution function.
1717
1818
Partially implements the interface of scipy.stats.rv_continuous, but all
1919
methods should be applicable along an axis so fields can be processed in
2020
a grid point-wise fashion.
2121
"""
2222

23-
@abc.abstracmethod
2423
@classmethod
24+
@abc.abstractmethod
2525
def fit(cls, sample, axis):
26-
"""Determine distribution parameters from a sample of data"""
26+
"""Determine distribution parameters from a sample of data."""
2727

2828
@abc.abstractmethod
2929
def cdf(self, x):
30-
"""Evaluate the continuous distribution function"""
30+
"""Evaluate the continuous distribution function."""
3131

3232
@abc.abstractmethod
3333
def ppf(self, x):
34-
"""Evaluate the inverse CDF (percent point function)"""
34+
"""Evaluate the inverse CDF (percent point function)."""
3535

3636

3737
# Temporary drop-in replacement for scipy.stats.lmoment from scipy v0.15
3838
def _lmoment(sample, order=(1, 2), axis=0):
39-
"""Compute first 2 L-moments of dataset along first axis"""
39+
"""Compute first 2 L-moments of dataset along first axis."""
4040
if len(order) != 2 or order[0] != 1 or order[1] != 2:
4141
raise NotImplementedError
4242
if axis != 0:
@@ -87,8 +87,12 @@ def _lmoment(sample, order=(1, 2), axis=0):
8787
return np.stack([sums[0], sums[1]])
8888

8989

90+
def _expand_dims_after(arr, ndim):
91+
return np.expand_dims(arr, axis=list(range(-ndim, 0)))
92+
93+
9094
class MaxGumbel(ContinuousDistribution):
91-
"""Gumbel distribution for extreme values
95+
"""Gumbel distribution for extreme values.
9296
9397
Parameters
9498
----------
@@ -98,14 +102,12 @@ class MaxGumbel(ContinuousDistribution):
98102
Scale parameter.
99103
"""
100104

101-
def __init__(self, mu, sigma): # TODO: needs an axis argument?
102-
self.mu = np.asarray(mu)
103-
self.sigma = np.asarray(sigma)
104-
assert self.mu.shape == self.sigma.shape
105+
def __init__(self, mu, sigma):
106+
self.mu, self.sigma = np.broadcast_arrays(mu, sigma)
105107

106108
@classmethod
107109
def fit(cls, sample, axis=0):
108-
"""Gumbel distribution with parameters fitted to sample values
110+
"""Gumbel distribution with parameters fitted to sample values.
109111
110112
Parameters
111113
----------
@@ -124,8 +126,18 @@ def fit(cls, sample, axis=0):
124126
mu = lmom[0] - sigma * 0.5772
125127
return cls(mu, sigma)
126128

129+
@property
130+
def shape(self):
131+
"""Tuple of dimensions."""
132+
return self.mu.shape
133+
134+
@property
135+
def ndim(self):
136+
"""Number of dimensions."""
137+
return self.mu.ndim
138+
127139
def cdf(self, x):
128-
"""Evaluate the cumulative distribution function
140+
"""Evaluate the cumulative distribution function.
129141
130142
Parameters
131143
----------
@@ -137,10 +149,11 @@ def cdf(self, x):
137149
The probability that a random variable X from the distribution is less
138150
than or equal to the input x.
139151
"""
140-
return 1.0 - np.exp(-np.exp((self.mu - x) / self.sigma)) # TODO vectorize along axis properly
152+
x = _expand_dims_after(x, self.ndim)
153+
return 1.0 - np.exp(-np.exp((self.mu - x) / self.sigma))
141154

142155
def ppf(self, p):
143-
"""Evaluate the inverse cumulative distribution function
156+
"""Evaluate the inverse cumulative distribution function.
144157
145158
Parameters
146159
----------
@@ -152,4 +165,5 @@ def ppf(self, p):
152165
x such that the probability of a random variable from the distribution
153166
taking a value less than or equal to x is p.
154167
"""
155-
return self.mu - self.sigma * np.log(-np.log(1.0 - p)) # TODO vectorize along axis properly
168+
p = _expand_dims_after(p, self.ndim)
169+
return self.mu - self.sigma * np.log(-np.log(1.0 - p))

src/earthkit/meteo/stats/array/extreme_values.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class MaximumStatistics:
14-
"""Recurrence statistics for a sample of maximum values
14+
"""Recurrence statistics for a sample of maximum values.
1515
1616
Parameters
1717
----------
@@ -22,7 +22,9 @@ class MaximumStatistics:
2222
The axis along which to compute the statistics.
2323
freq: number | timedelta
2424
Temporal frequency of the input data. Used to scale return periods.
25-
Defaults to 1, i.e., no scaling applied.
25+
Defaults to 1, i.e., no scaling applied. When supplying a numpy
26+
timedelta64, unit carries over to return periods, so make sure the
27+
resolution is sufficient.
2628
dist:
2729
Continuous probability distribution fitted to the input data.
2830
@@ -41,16 +43,16 @@ def __init__(self, sample, axis=0, freq=1.0, dist=MaxGumbel):
4143

4244
@property
4345
def dist(self):
44-
"""Estimated ontinuous probability distribution for the data"""
46+
"""Estimated ontinuous probability distribution for the data."""
4547
return self._dist
4648

4749
@property
4850
def freq(self):
49-
"""Temporal frequency used for scaling return periods"""
51+
"""Temporal frequency used for scaling return periods."""
5052
return self._freq
5153

5254
def probability_of_threshold(self, threshold):
53-
"""Probability of exceeding the threshold
55+
"""Probability of exceeding the threshold.
5456
5557
Parameters
5658
----------
@@ -65,7 +67,7 @@ def probability_of_threshold(self, threshold):
6567
return self.dist.cdf(threshold)
6668

6769
def return_period_of_threshold(self, threshold):
68-
"""Return period of exceeding the threshold
70+
"""Return period of exceeding the threshold.
6971
7072
Parameters
7173
----------
@@ -79,7 +81,7 @@ def return_period_of_threshold(self, threshold):
7981
return self.freq / self.probability_of_threshold(threshold)
8082

8183
def threshold_of_probability(self, probability):
82-
"""The threshold of a given probability of exceedance
84+
"""Threshold of a given probability of exceedance.
8385
8486
Parameters
8587
----------
@@ -93,7 +95,7 @@ def threshold_of_probability(self, probability):
9395
return self.dist.ppf(probability)
9496

9597
def threshold_of_return_period(self, return_period):
96-
"""The threshold of a given return period
98+
"""Threshold of a given return period.
9799
98100
Parameters
99101
----------

0 commit comments

Comments
 (0)