Skip to content

Commit 3485498

Browse files
authored
Feature detector extensions (#225)
* Remove unnecessary encoding lines * Docstring revisions * Add interface documentation * Docstring revision * Remove unnecessary encoding line * Add tests * Remove unnecessary import * Add option to restrict the number of blobs based on their scale-space intensity * Add max_num_features keyword argument * Add tests for max_num_features * Add documentation about feature selection when max_num_features is not None * Add area to cell attributes and option to choose the maximum number of cells by area * Add tests for max_num_features * Add normalization of blob intensities by multiplication with sigma**2 * Fix typo in if clause * Adjust the number of columns to include cell area * Sort IDs to ensure correct comparison * Fix incorrect indexing * Remove unused metadata variable * Rename input to input_field to avoid redefining a built-in
1 parent c3cb993 commit 3485498

File tree

7 files changed

+106
-27
lines changed

7 files changed

+106
-27
lines changed

pysteps/feature/blob.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.feature.blob
43
====================
@@ -15,6 +14,8 @@
1514

1615
from pysteps.exceptions import MissingOptionalDependency
1716

17+
from scipy.ndimage import gaussian_laplace
18+
1819
try:
1920
from skimage import feature
2021

@@ -25,6 +26,7 @@
2526

2627
def detection(
2728
input_image,
29+
max_num_features=None,
2830
method="log",
2931
threshold=0.5,
3032
min_sigma=3,
@@ -47,6 +49,11 @@ def detection(
4749
----------
4850
input_image: array_like
4951
Array of shape (m, n) containing the input image. NaN values are ignored.
52+
max_num_features : int, optional
53+
The maximum number of blobs to detect. Set to None for no restriction.
54+
If specified, the most significant blobs are chosen based on their
55+
intensities in the corresponding Laplacian of Gaussian (LoG)-filtered
56+
images.
5057
method: {'log', 'dog', 'doh'}, optional
5158
The method to use: 'log' = Laplacian of Gaussian, 'dog' = Difference of
5259
Gaussian, 'doh' = Determinant of Hessian.
@@ -95,6 +102,14 @@ def detection(
95102
**kwargs,
96103
)
97104

105+
if max_num_features is not None and blobs.shape[0] > max_num_features:
106+
blob_intensities = []
107+
for i in range(blobs.shape[0]):
108+
gl_image = -gaussian_laplace(input_image, blobs[i, 2]) * blobs[i, 2] ** 2
109+
blob_intensities.append(gl_image[int(blobs[i, 0]), int(blobs[i, 1])])
110+
idx = np.argsort(blob_intensities)[::-1]
111+
blobs = blobs[idx[:max_num_features], :]
112+
98113
if not return_sigmas:
99114
return np.column_stack([blobs[:, 1], blobs[:, 0]])
100115
else:

pysteps/feature/interface.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,27 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.feature.interface
43
=========================
54
65
Interface for the feature detection module. It returns a callable function for
7-
detecting features.
6+
detecting features from two-dimensional images.
7+
8+
The feature detectors implement the following interface:
9+
10+
``detection(input_image, **keywords)``
11+
12+
The input is a two-dimensional image. Additional arguments to the specific
13+
method can be given via **keywords. The output is an array of shape (n, m),
14+
where each row corresponds to one of the n features. The first two columns
15+
contain the coordinates (x, y) of the features, and additional information can
16+
be specified in the remaining columns.
17+
18+
All implemented methods support the following keyword arguments:
19+
20+
+------------------+-----------------------------------------------------+
21+
| Key | Value |
22+
+==================+=====================================================+
23+
| max_num_features | maximum number of features to detect |
24+
+------------------+-----------------------------------------------------+
825
926
.. autosummary::
1027
:toctree: ../generated/
@@ -23,10 +40,7 @@
2340

2441

2542
def get_method(name):
26-
"""Return a callable function for computing detection.
27-
28-
Description:
29-
Return a callable function for detecting features on input images .
43+
"""Return a callable function for feature detection.
3044
3145
Implemented methods:
3246

pysteps/feature/shitomasi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.feature.shitomasi
43
=========================
@@ -27,6 +26,7 @@
2726
def detection(
2827
input_image,
2928
max_corners=1000,
29+
max_num_features=None,
3030
quality_level=0.01,
3131
min_distance=10,
3232
block_size=5,
@@ -73,6 +73,10 @@ def detection(
7373
method.
7474
It represents the maximum number of points to be tracked (corners).
7575
If set to zero, all detected corners are used.
76+
max_num_features: int, optional
77+
If specified, this argument is substituted for max_corners. Set to None
78+
for no restriction. Added for compatibility with the feature detector
79+
interface.
7680
quality_level: float, optional
7781
The ``qualityLevel`` parameter in the `Shi-Tomasi`_ corner detection
7882
method.
@@ -148,7 +152,7 @@ def detection(
148152
mask = (-1 * mask + 1).astype("uint8")
149153

150154
params = dict(
151-
maxCorners=max_corners,
155+
maxCorners=max_num_features if max_num_features is not None else max_corners,
152156
qualityLevel=quality_level,
153157
minDistance=min_distance,
154158
blockSize=block_size,

pysteps/feature/tstorm.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.feature.tstorm
43
======================
@@ -49,6 +48,7 @@
4948

5049
def detection(
5150
input_image,
51+
max_num_features=None,
5252
minref=35,
5353
maxref=48,
5454
mindiff=6,
@@ -69,6 +69,9 @@ def detection(
6969
input_image: array-like
7070
Array of shape (m,n) containing input image, usually maximum reflectivity in
7171
dBZ with a resolution of 1 km. NaN values are ignored.
72+
max_num_features : int, optional
73+
The maximum number of cells to detect. Set to None for no restriction.
74+
If specified, the most significant cells are chosen based on their area.
7275
minref: float, optional
7376
Lower threshold for object detection. Lower values will be set to NaN.
7477
The default is 35 dBZ.
@@ -163,10 +166,22 @@ def detection(
163166

164167
cells_id, labels = get_profile(areas, binary, input_image, loc_max, time, minref)
165168

169+
if max_num_features is not None:
170+
idx = np.argsort(cells_id.area.to_numpy())[::-1]
171+
166172
if not output_feat:
167-
return cells_id, labels
173+
if max_num_features is None:
174+
return cells_id, labels
175+
else:
176+
for i in idx[max_num_features:]:
177+
labels[labels == cells_id.ID[i]] = 0
178+
return cells_id.loc[idx[:max_num_features]], labels
168179
if output_feat:
169-
return np.column_stack([np.array(cells_id.cen_x), np.array(cells_id.cen_y)])
180+
out = np.column_stack([np.array(cells_id.cen_x), np.array(cells_id.cen_y)])
181+
if max_num_features is not None:
182+
out = out[idx[:max_num_features], :]
183+
184+
return out
170185

171186

172187
def breakup(ref, minval, maxima):
@@ -219,7 +234,7 @@ def get_profile(areas, binary, ref, loc_max, time, minref):
219234
cells_id = pd.DataFrame(
220235
data=None,
221236
index=range(len(cell_labels)),
222-
columns=["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont"],
237+
columns=["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont", "area"],
223238
)
224239
cells_id.time = time
225240
for n in range(len(cell_labels)):
@@ -235,6 +250,7 @@ def get_profile(areas, binary, ref, loc_max, time, minref):
235250
cells_id.cen_x.iloc[n] = int(np.nanmean(cells_id.x[n])) # int(x[0])
236251
cells_id.cen_y.iloc[n] = int(np.nanmean(cells_id.y[n])) # int(y[0])
237252
cells_id.max_ref.iloc[n] = maxref
253+
cells_id.area.iloc[n] = len(cells_id.x.iloc[n])
238254
labels[cells == cell_labels[n]] = ID
239255

240256
return cells_id, labels

pysteps/tests/test_feature.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
import numpy as np
3+
from pysteps import feature
4+
from pysteps.tests.helpers import get_precipitation_fields
5+
6+
arg_names = ["method", "max_num_features"]
7+
arg_values = [("blob", None), ("blob", 5), ("shitomasi", None), ("shitomasi", 5)]
8+
9+
10+
@pytest.mark.parametrize(arg_names, arg_values)
11+
def test_feature(method, max_num_features):
12+
input_field, _ = get_precipitation_fields(0, 0, True, True, None, "mch")
13+
14+
detector = feature.get_method(method)
15+
16+
kwargs = {"max_num_features": max_num_features}
17+
output = detector(input_field.squeeze(), **kwargs)
18+
19+
assert isinstance(output, np.ndarray)
20+
assert output.ndim == 2
21+
assert output.shape[0] > 0
22+
if max_num_features is not None:
23+
assert output.shape[0] <= max_num_features
24+
assert output.shape[1] == 2

pysteps/tests/test_feature_tstorm.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# -*- coding: utf-8 -*-
2-
3-
import datetime as dt
4-
51
import numpy as np
62
import pytest
73

@@ -14,17 +10,20 @@
1410
except ModuleNotFoundError:
1511
pass
1612

17-
arg_names = ("source", "output_feat", "dry_input")
13+
arg_names = ("source", "output_feat", "dry_input", "max_num_features")
1814

1915
arg_values = [
20-
("mch", False, False),
21-
("mch", True, False),
22-
("mch", False, True),
16+
("mch", False, False, None),
17+
("mch", False, False, 5),
18+
("mch", True, False, None),
19+
("mch", True, False, 5),
20+
("mch", False, True, None),
21+
("mch", False, True, 5),
2322
]
2423

2524

2625
@pytest.mark.parametrize(arg_names, arg_values)
27-
def test_feature_tstorm_detection(source, output_feat, dry_input):
26+
def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features):
2827

2928
pytest.importorskip("skimage")
3029
pytest.importorskip("pandas")
@@ -37,18 +36,24 @@ def test_feature_tstorm_detection(source, output_feat, dry_input):
3736
input = np.zeros((50, 50))
3837

3938
time = "000"
40-
output = detection(input, time=time, output_feat=output_feat)
39+
output = detection(
40+
input, time=time, output_feat=output_feat, max_num_features=max_num_features
41+
)
4142

4243
if output_feat:
4344
assert isinstance(output, np.ndarray)
4445
assert output.ndim == 2
4546
assert output.shape[1] == 2
47+
if max_num_features is not None:
48+
assert output.shape[0] <= max_num_features
4649
else:
4750
assert isinstance(output, tuple)
4851
assert len(output) == 2
4952
assert isinstance(output[0], DataFrame)
5053
assert isinstance(output[1], np.ndarray)
51-
assert output[0].shape[1] == 8
54+
if max_num_features is not None:
55+
assert output[0].shape[0] <= max_num_features
56+
assert output[0].shape[1] == 9
5257
assert list(output[0].columns) == [
5358
"ID",
5459
"time",
@@ -58,13 +63,14 @@ def test_feature_tstorm_detection(source, output_feat, dry_input):
5863
"cen_y",
5964
"max_ref",
6065
"cont",
66+
"area",
6167
]
6268
assert (output[0].time == time).all()
6369
assert output[1].ndim == 2
6470
assert output[1].shape == input.shape
6571
if not dry_input:
6672
assert output[0].shape[0] > 0
67-
assert list(output[0].ID) == list(np.unique(output[1]))[1:]
73+
assert sorted(list(output[0].ID)) == sorted(list(np.unique(output[1]))[1:])
6874
else:
6975
assert output[0].shape[0] == 0
7076
assert output[1].sum() == 0

pysteps/tests/test_tracking_tdating.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ def test_tracking_tdating_dating(source, dry_input):
4242
assert len(output[2]) == input.shape[0]
4343
assert isinstance(output[1][0], pandas.DataFrame)
4444
assert isinstance(output[2][0], np.ndarray)
45-
assert output[1][0].shape[1] == 8
45+
assert output[1][0].shape[1] == 9
4646
assert output[2][0].shape == input.shape[1:]
4747
if not dry_input:
4848
assert len(output[0]) > 0
4949
assert isinstance(output[0][0], pandas.DataFrame)
50-
assert output[0][0].shape[1] == 8
50+
assert output[0][0].shape[1] == 9
5151
else:
5252
assert len(output[0]) == 0
5353
assert output[1][0].shape[0] == 0

0 commit comments

Comments
 (0)