Skip to content

Commit e371fb2

Browse files
committed
Define equality for Label Extension component classes
1 parent 5539471 commit e371fb2

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

pystac/extensions/label.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ def __repr__(self) -> str:
133133
",".join([str(x) for x in self.classes])
134134
)
135135

136+
def __eq__(self, o: object) -> bool:
137+
if isinstance(o, LabelClasses):
138+
o = o.to_dict()
139+
140+
if not isinstance(o, dict):
141+
return NotImplemented
142+
143+
return self.to_dict() == o
144+
136145
def to_dict(self) -> Dict[str, Any]:
137146
"""Returns the dictionary representing the JSON of this instance."""
138147
return self.properties
@@ -192,6 +201,15 @@ def to_dict(self) -> Dict[str, Any]:
192201
"""Returns the dictionary representing the JSON of this instance."""
193202
return self.properties
194203

204+
def __eq__(self, o: object) -> bool:
205+
if isinstance(o, LabelCount):
206+
o = o.to_dict()
207+
208+
if not isinstance(o, dict):
209+
return NotImplemented
210+
211+
return self.to_dict() == o
212+
195213

196214
class LabelStatistics:
197215
"""Contains statistics for regression/continuous numeric value data.
@@ -246,6 +264,15 @@ def to_dict(self) -> Dict[str, Any]:
246264
"""Returns the dictionary representing the JSON of this LabelStatistics."""
247265
return self.properties
248266

267+
def __eq__(self, o: object) -> bool:
268+
if isinstance(o, LabelStatistics):
269+
o = o.to_dict()
270+
271+
if not isinstance(o, dict):
272+
return NotImplemented
273+
274+
return self.to_dict() == o
275+
249276

250277
class LabelOverview:
251278
"""Stores counts (for classification-type data) or summary statistics (for
@@ -391,6 +418,15 @@ def to_dict(self) -> Dict[str, Any]:
391418
"""Returns the dictionary representing the JSON of this LabelOverview."""
392419
return self.properties
393420

421+
def __eq__(self, o: object) -> bool:
422+
if isinstance(o, LabelOverview):
423+
o = o.to_dict()
424+
425+
if not isinstance(o, dict):
426+
return NotImplemented
427+
428+
return self.to_dict() == o
429+
394430

395431
class LabelExtension(ExtensionManagementMixin[pystac.Item]):
396432
"""A class that can be used to extend the properties of an

tests/extensions/test_label.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import unittest
44
import tempfile
5+
from typing import List, Union
56

67
import pystac
78
from pystac import Catalog, Item, CatalogType
@@ -30,6 +31,59 @@ def test_rel_types(self) -> None:
3031
self.assertEqual(str(LabelRelType.SOURCE), "source")
3132

3233

34+
class LabelTaskTest(unittest.TestCase):
35+
def test_rel_types(self) -> None:
36+
self.assertEqual(str(LabelTask.REGRESSION), "regression")
37+
self.assertEqual(str(LabelTask.CLASSIFICATION), "classification")
38+
self.assertEqual(str(LabelTask.DETECTION), "detection")
39+
self.assertEqual(str(LabelTask.SEGMENTATION), "segmentation")
40+
41+
42+
class LabelCountTest(unittest.TestCase):
43+
def test_label_count_equality(self) -> None:
44+
count1 = LabelCount.create(name="prop", count=1)
45+
count2 = LabelCount.create(name="prop", count=1)
46+
count3 = LabelCount.create(name="other", count=1)
47+
count4 = LabelCount.create(name="prop", count=2)
48+
49+
self.assertEqual(count1, count2)
50+
self.assertNotEqual(count1, count3)
51+
self.assertNotEqual(count1, count4)
52+
self.assertNotEqual(count1, 42)
53+
54+
55+
class LabelOverviewTest(unittest.TestCase):
56+
def test_label_count_equality(self) -> None:
57+
stats1 = LabelStatistics.create(name="prop", value=42.3)
58+
count1 = LabelCount.create(name="prop", count=1)
59+
60+
overview1 = LabelOverview.create(
61+
property_key="first", counts=[count1], statistics=[stats1]
62+
)
63+
overview2 = LabelOverview.create(
64+
property_key="first", counts=[count1], statistics=[stats1]
65+
)
66+
overview3 = LabelOverview.create(property_key="first", counts=[count1])
67+
overview4 = LabelOverview.create(property_key="first", statistics=[stats1])
68+
self.assertEqual(overview1, overview2)
69+
self.assertNotEqual(overview1, overview3)
70+
self.assertNotEqual(overview1, overview4)
71+
self.assertNotEqual(overview1, 42)
72+
73+
74+
class LabelStatisticsTest(unittest.TestCase):
75+
def test_label_statistics_equality(self) -> None:
76+
stats1 = LabelStatistics.create(name="prop", value=42.3)
77+
stats2 = LabelStatistics.create(name="prop", value=42.3)
78+
stats3 = LabelStatistics.create(name="other", value=42.3)
79+
stats4 = LabelStatistics.create(name="prop", value=73.4)
80+
81+
self.assertEqual(stats1, stats2)
82+
self.assertNotEqual(stats1, stats3)
83+
self.assertNotEqual(stats1, stats4)
84+
self.assertNotEqual(stats1, 42)
85+
86+
3387
class LabelTest(unittest.TestCase):
3488
def setUp(self) -> None:
3589
self.maxDiff = None

0 commit comments

Comments
 (0)