Commit 3e433ff

ritvje and dnerini authored
Consider splits and merges in tdating (#349)
* Ignore 0 in cell match

  Since 0 in labels means that no cell was identified in that pixel, we need to ignore
  0 when matching. Otherwise we could have a situation where, e.g., the cell overlaps
  0.55 with 0 and 0.45 with some cell and would end up unmatched (`ID_coverage` would
  be 0 and a new track would be initialized).

* Get the required match overlap fraction as a parameter

* Store information about splits

  If the advected cell overlaps > 10% with more than one cell at the next timestep,
  consider the advected cell a split cell and store the IDs it split into. Also mark
  cells at the current timestep that resulted from splits.

* Store information about merges

  If the cell at the current timestep is overlapped > 10% by more than one advected
  cell, consider it merged and store the IDs of the previous cells. Also mark cells
  from the previous timestep if they will merge at the next timestep.

* Add columns to cell dataframes

* Make splits/merges output optional

* Add options to specify the fractions required for matching/splitting/merging

* Update tests to account for split/merge output in tdating

* Fix unused variables code check

* Fix match_frac argument in tracking

* Refactor to avoid chained assignment warnings in pandas

* Add short example

---------

Co-authored-by: Daniele Nerini <daniele.nerini@gmail.com>
1 parent d77fe73 commit 3e433ff
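
The matching and split/merge rules in the commit message boil down to overlap fractions between labelled cell areas. Below is a minimal sketch of that bookkeeping, not the actual pysteps tdating code: the helper names (`overlap_fractions`, `split_ids`) and the `split_frac` argument are hypothetical, while the rule of ignoring label 0 and the 10% default threshold come from the commit message above.

import numpy as np


def overlap_fractions(advected_cell_mask, labels_next, ignore_label=0):
    # Fraction of the advected cell area covered by each labelled cell at t+1.
    # Label 0 means "no cell identified" and is skipped, as the commit message requires.
    ids, counts = np.unique(labels_next[advected_cell_mask], return_counts=True)
    total = advected_cell_mask.sum()
    return {
        int(cell_id): count / total
        for cell_id, count in zip(ids, counts)
        if cell_id != ignore_label
    }


def split_ids(advected_cell_mask, labels_next, split_frac=0.1):
    # IDs at t+1 that each cover more than split_frac of the advected cell;
    # more than one such ID would mark the advected cell as split.
    fractions = overlap_fractions(advected_cell_mask, labels_next)
    return [cell_id for cell_id, frac in fractions.items() if frac > split_frac]

The merge case is symmetric: compute the fractions of a current cell covered by each advected cell and flag a merge when more than one fraction exceeds the threshold.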

File tree

5 files changed: +321 -59 lines changed

examples/thunderstorm_detection_and_tracking.py

Lines changed: 11 additions & 0 deletions
@@ -90,6 +90,17 @@
 # Properties of one of the identified cells:
 print(cells_id.iloc[0])

+###############################################################################
+# Optionally, one can also ask to consider splits and merges of thunderstorm cells.
+# A cell at time t is considered to split if it will overlap more than 10% with more
+# than one cell at time t+1. Conversely, a cell is considered to be a merge if more
+# than one cell from time t will overlap more than 10% with it.
+
+cells_id, labels = tstorm_detect.detection(
+    input_image, time=time, output_splits_merges=True
+)
+print(cells_id.iloc[0])
+
 ###############################################################################
 # Example of thunderstorm tracking over a timeseries
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
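
A possible follow-up to the detection example above, not part of this commit: once output_splits_merges=True is also passed to the tracking step, the per-timestep cell dataframes carry the new columns, so merged cells can be listed directly. The tstorm_dating alias and the input_images/timelist variables are hypothetical stand-ins for the tracking module import and the time series used later in the example.

track_list, cell_list, label_list = tstorm_dating.dating(
    input_images, timelist, mintrack=1, output_splits_merges=True
)

# Cells flagged as merges at the last timestep and the IDs they merged from;
# the "merged" column holds None/True/False, so compare against True explicitly.
last_cells = cell_list[-1]
print(last_cells.loc[last_cells["merged"] == True, ["ID", "merged_IDs"]])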

pysteps/feature/tstorm.py

Lines changed: 55 additions & 4 deletions
@@ -58,6 +58,7 @@ def detection(
     minmax=41,
     mindis=10,
     output_feat=False,
+    output_splits_merges=False,
     time="000000000",
 ):
     """
@@ -93,6 +94,10 @@ def detection(
         smaller distance will be merged. The default is 10 km.
     output_feat: bool, optional
         Set to True to return only the cell coordinates.
+    output_splits_merges: bool, optional
+        Set to True to return additional columns in the dataframe describing the
+        splitting and merging of cells. Note that the columns are initialized with
+        None and are only filled in during tracking.
     time: string, optional
         Date and time as string. Used to label time in the resulting dataframe.
         The default is '000000000'.
@@ -166,7 +171,15 @@ def detection(

     areas, lines = breakup(input_image, np.nanmin(input_image.flatten()), maxima_dis)

-    cells_id, labels = get_profile(areas, binary, input_image, loc_max, time, minref)
+    cells_id, labels = get_profile(
+        areas,
+        binary,
+        input_image,
+        loc_max,
+        time,
+        minref,
+        output_splits_merges=output_splits_merges,
+    )

     if max_num_features is not None:
         idx = np.argsort(cells_id.area.to_numpy())[::-1]
@@ -225,10 +238,12 @@ def longdistance(loc_max, mindis):
     return new_max


-def get_profile(areas, binary, ref, loc_max, time, minref):
+def get_profile(areas, binary, ref, loc_max, time, minref, output_splits_merges=False):
     """
     This function returns the identified cells in a dataframe including their x,y
     locations, location of their maxima, maximum reflectivity and contours.
+    Optionally, the dataframe can include columns for storing information regarding
+    splitting and merging of cells.
     """
     cells = areas * binary
     cell_labels = cells[loc_max]
@@ -255,11 +270,47 @@ def get_profile(areas, binary, ref, loc_max, time, minref):
                 "area": len(x),
             }
         )
+        if output_splits_merges:
+            cells_id[-1].update(
+                {
+                    "splitted": None,
+                    "split_IDs": None,
+                    "merged": None,
+                    "merged_IDs": None,
+                    "results_from_split": None,
+                    "will_merge": None,
+                }
+            )
         labels[cells == cell_labels[n]] = this_id
+
+    columns = [
+        "ID",
+        "time",
+        "x",
+        "y",
+        "cen_x",
+        "cen_y",
+        "max_ref",
+        "cont",
+        "area",
+    ]
+    if output_splits_merges:
+        columns.extend(
+            [
+                "splitted",
+                "split_IDs",
+                "merged",
+                "merged_IDs",
+                "results_from_split",
+                "will_merge",
+            ]
+        )
     cells_id = pd.DataFrame(
         data=cells_id,
         index=range(len(cell_labels)),
-        columns=["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont", "area"],
+        columns=columns,
     )
-
+    if output_splits_merges:
+        cells_id["split_IDs"] = cells_id["split_IDs"].astype("object")
+        cells_id["merged_IDs"] = cells_id["merged_IDs"].astype("object")
     return cells_id, labels
pysteps/tests/test_feature_tstorm.py

Lines changed: 56 additions & 9 deletions
@@ -10,20 +10,29 @@
 except ModuleNotFoundError:
     pass

-arg_names = ("source", "output_feat", "dry_input", "max_num_features")
+arg_names = (
+    "source",
+    "output_feat",
+    "dry_input",
+    "max_num_features",
+    "output_split_merge",
+)

 arg_values = [
-    ("mch", False, False, None),
-    ("mch", False, False, 5),
-    ("mch", True, False, None),
-    ("mch", True, False, 5),
-    ("mch", False, True, None),
-    ("mch", False, True, 5),
+    ("mch", False, False, None, False),
+    ("mch", False, False, 5, False),
+    ("mch", True, False, None, False),
+    ("mch", True, False, 5, False),
+    ("mch", False, True, None, False),
+    ("mch", False, True, 5, False),
+    ("mch", False, False, None, True),
 ]


 @pytest.mark.parametrize(arg_names, arg_values)
-def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features):
+def test_feature_tstorm_detection(
+    source, output_feat, dry_input, max_num_features, output_split_merge
+):
     pytest.importorskip("skimage")
     pytest.importorskip("pandas")

@@ -36,7 +45,11 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features

     time = "000"
     output = detection(
-        input, time=time, output_feat=output_feat, max_num_features=max_num_features
+        input,
+        time=time,
+        output_feat=output_feat,
+        max_num_features=max_num_features,
+        output_splits_merges=output_split_merge,
     )

     if output_feat:
@@ -45,6 +58,40 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features
         assert output.shape[1] == 2
         if max_num_features is not None:
             assert output.shape[0] <= max_num_features
+    elif output_split_merge:
+        assert isinstance(output, tuple)
+        assert len(output) == 2
+        assert isinstance(output[0], DataFrame)
+        assert isinstance(output[1], np.ndarray)
+        if max_num_features is not None:
+            assert output[0].shape[0] <= max_num_features
+        assert output[0].shape[1] == 15
+        assert list(output[0].columns) == [
+            "ID",
+            "time",
+            "x",
+            "y",
+            "cen_x",
+            "cen_y",
+            "max_ref",
+            "cont",
+            "area",
+            "splitted",
+            "split_IDs",
+            "merged",
+            "merged_IDs",
+            "results_from_split",
+            "will_merge",
+        ]
+        assert (output[0].time == time).all()
+        assert output[1].ndim == 2
+        assert output[1].shape == input.shape
+        if not dry_input:
+            assert output[0].shape[0] > 0
+            assert sorted(list(output[0].ID)) == sorted(list(np.unique(output[1]))[1:])
+        else:
+            assert output[0].shape[0] == 0
+            assert output[1].sum() == 0
     else:
         assert isinstance(output, tuple)
         assert len(output) == 2

pysteps/tests/test_tracking_tdating.py

Lines changed: 21 additions & 11 deletions
@@ -7,22 +7,24 @@
 from pysteps.utils import to_reflectivity
 from pysteps.tests.helpers import get_precipitation_fields

-arg_names = ("source", "dry_input")
+arg_names = ("source", "dry_input", "output_splits_merges")

 arg_values = [
-    ("mch", False),
-    ("mch", False),
-    ("mch", True),
+    ("mch", False, False),
+    ("mch", False, False),
+    ("mch", True, False),
+    ("mch", False, True),
 ]

-arg_names_multistep = ("source", "len_timesteps")
+arg_names_multistep = ("source", "len_timesteps", "output_splits_merges")
 arg_values_multistep = [
-    ("mch", 6),
+    ("mch", 6, False),
+    ("mch", 6, True),
 ]


 @pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)
-def test_tracking_tdating_dating_multistep(source, len_timesteps):
+def test_tracking_tdating_dating_multistep(source, len_timesteps, output_splits_merges):
     pytest.importorskip("skimage")

     input_fields, metadata = get_precipitation_fields(
@@ -37,6 +39,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
         input_fields[0 : len_timesteps // 2],
         timelist[0 : len_timesteps // 2],
         mintrack=1,
+        output_splits_merges=output_splits_merges,
     )
     # Second half of timesteps
     tracks_2, cells, _ = dating(
@@ -46,6 +49,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
         start=2,
         cell_list=cells,
         label_list=labels,
+        output_splits_merges=output_splits_merges,
     )

     # Since we are adding cells, number of tracks should increase
@@ -67,7 +71,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):


 @pytest.mark.parametrize(arg_names, arg_values)
-def test_tracking_tdating_dating(source, dry_input):
+def test_tracking_tdating_dating(source, dry_input, output_splits_merges):
     pytest.importorskip("skimage")
     pandas = pytest.importorskip("pandas")

@@ -80,7 +84,13 @@ def test_tracking_tdating_dating(source, dry_input):

     timelist = metadata["timestamps"]

-    output = dating(input, timelist, mintrack=1)
+    cell_column_length = 9
+    if output_splits_merges:
+        cell_column_length = 15
+
+    output = dating(
+        input, timelist, mintrack=1, output_splits_merges=output_splits_merges
+    )

     # Check output format
     assert isinstance(output, tuple)
@@ -92,12 +102,12 @@ def test_tracking_tdating_dating(source, dry_input):
     assert len(output[2]) == input.shape[0]
     assert isinstance(output[1][0], pandas.DataFrame)
     assert isinstance(output[2][0], np.ndarray)
-    assert output[1][0].shape[1] == 9
+    assert output[1][0].shape[1] == cell_column_length
     assert output[2][0].shape == input.shape[1:]
     if not dry_input:
         assert len(output[0]) > 0
         assert isinstance(output[0][0], pandas.DataFrame)
-        assert output[0][0].shape[1] == 9
+        assert output[0][0].shape[1] == cell_column_length
     else:
         assert len(output[0]) == 0
         assert output[1][0].shape[0] == 0
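
For reference, the two values of cell_column_length asserted above follow directly from the columns defined in get_profile: 9 base columns plus 6 split/merge columns. A quick illustrative check:

base_columns = ["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont", "area"]
splits_merges_columns = [
    "splitted", "split_IDs", "merged", "merged_IDs", "results_from_split", "will_merge"
]
assert len(base_columns) == 9  # default cell dataframe width
assert len(base_columns + splits_merges_columns) == 15  # with output_splits_merges=True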
