Skip to content

Commit 1ba9831

Browse files
committed
Improve test coverage for Label Extension
1 parent 313c1fd commit 1ba9831

File tree

1 file changed

+96
-7
lines changed

1 file changed

+96
-7
lines changed

tests/extensions/test_label.py

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818
from tests.utils import TestCases, assert_to_from_dict, get_temp_dir
1919

2020

21+
class LabelTypeTest(unittest.TestCase):
22+
def test_to_str(self) -> None:
23+
self.assertEqual(str(LabelType.VECTOR), "vector")
24+
self.assertEqual(str(LabelType.RASTER), "raster")
25+
26+
27+
class LabelRelTypeTest(unittest.TestCase):
28+
def test_rel_types(self) -> None:
29+
self.assertEqual(str(LabelRelType.SOURCE), "source")
30+
31+
2132
class LabelTest(unittest.TestCase):
2233
def setUp(self) -> None:
2334
self.maxDiff = None
@@ -28,9 +39,6 @@ def setUp(self) -> None:
2839
"data-files/label/label-example-2.json"
2940
)
3041

31-
def test_rel_types(self) -> None:
32-
self.assertEqual(str(LabelRelType.SOURCE), "source")
33-
3442
def test_to_from_dict(self) -> None:
3543
with open(self.label_example_1_uri, encoding="utf-8") as f:
3644
label_example_1_dict = json.load(f)
@@ -163,7 +171,8 @@ def test_label_classes(self) -> None:
163171
LabelClasses.create(name="label", classes=["seven", "eight"]),
164172
]
165173

166-
LabelExtension.ext(label_item).label_classes = new_classes
174+
label_ext = LabelExtension.ext(label_item)
175+
label_ext.label_classes = new_classes
167176
self.assertEqual(
168177
[
169178
class_name
@@ -173,6 +182,13 @@ def test_label_classes(self) -> None:
173182
["five", "six", "seven", "eight"],
174183
)
175184

185+
self.assertListEqual(
186+
[lc.name for lc in label_ext.label_classes], ["label2", "label"]
187+
)
188+
189+
first_lc = label_ext.label_classes[0]
190+
self.assertEqual("<ClassObject classes=five,six>", first_lc.__repr__())
191+
176192
label_item.validate()
177193

178194
def test_label_tasks(self) -> None:
@@ -219,12 +235,13 @@ def test_label_overviews(self) -> None:
219235

220236
label_counts = get_opt(label_overviews[0].counts)
221237
self.assertEqual(label_counts[1].count, 17)
222-
fisrt_overview_counts = get_opt(label_ext.label_overviews)[0].counts
223-
assert fisrt_overview_counts is not None
224-
fisrt_overview_counts[1].count = 18
238+
first_overview_counts = get_opt(label_ext.label_overviews)[0].counts
239+
assert first_overview_counts is not None
240+
first_overview_counts[1].count = 18
225241
self.assertEqual(
226242
label_item.properties["label:overviews"][0]["counts"][1]["count"], 18
227243
)
244+
self.assertEqual(first_overview_counts[1].name, "two")
228245

229246
label_statistics = get_opt(label_overviews[1].statistics)
230247
self.assertEqual(label_statistics[0].name, "mean")
@@ -271,3 +288,75 @@ def test_label_overviews(self) -> None:
271288
)
272289

273290
label_item.validate()
291+
292+
def test_merge_label_overviews(self) -> None:
293+
294+
overview_1 = LabelOverview.create(
295+
property_key="label",
296+
counts=[
297+
LabelCount.create(name="water", count=25),
298+
LabelCount.create(name="land", count=17),
299+
],
300+
)
301+
overview_2 = LabelOverview.create(
302+
property_key="label",
303+
counts=[
304+
LabelCount.create(name="water", count=10),
305+
LabelCount.create(name="unknown", count=4),
306+
],
307+
)
308+
merged_overview = overview_1.merge_counts(overview_2)
309+
310+
merged_counts = get_opt(merged_overview.counts)
311+
312+
water_count = next(c for c in merged_counts if c.name == "water")
313+
land_count = next(c for c in merged_counts if c.name == "land")
314+
unknown_count = next(c for c in merged_counts if c.name == "unknown")
315+
316+
self.assertEqual(35, water_count.count)
317+
self.assertEqual(17, land_count.count)
318+
self.assertEqual(4, unknown_count.count)
319+
320+
def test_merge_label_overviews_empty_counts(self) -> None:
321+
# Right side is empty
322+
overview_1 = LabelOverview.create(
323+
property_key="label",
324+
counts=[
325+
LabelCount.create(name="water", count=25),
326+
LabelCount.create(name="land", count=17),
327+
],
328+
)
329+
overview_2 = LabelOverview.create(
330+
property_key="label",
331+
counts=None,
332+
)
333+
334+
merged_overview_1 = overview_1.merge_counts(overview_2)
335+
expected_counts = [c.to_dict() for c in get_opt(overview_1.counts)]
336+
actual_counts = [c.to_dict() for c in get_opt(merged_overview_1.counts)]
337+
self.assertListEqual(expected_counts, actual_counts)
338+
339+
# Left side is empty
340+
merged_overview_2 = overview_2.merge_counts(overview_1)
341+
expected_counts = [c.to_dict() for c in get_opt(overview_1.counts)]
342+
actual_counts = [c.to_dict() for c in get_opt(merged_overview_2.counts)]
343+
self.assertEqual(expected_counts, actual_counts)
344+
345+
def test_merge_label_overviews_error(self) -> None:
346+
overview_1 = LabelOverview.create(
347+
property_key="label",
348+
counts=[
349+
LabelCount.create(name="water", count=25),
350+
LabelCount.create(name="land", count=17),
351+
],
352+
)
353+
overview_2 = LabelOverview.create(
354+
property_key="not label",
355+
counts=[
356+
LabelCount.create(name="water", count=10),
357+
LabelCount.create(name="unknown", count=4),
358+
],
359+
)
360+
361+
with self.assertRaises(AssertionError):
362+
_ = overview_1.merge_counts(overview_2)

0 commit comments

Comments
 (0)