Skip to content

Commit f1f02a2

Browse files
committed
Use cls argument in from_dict methods
1 parent a729cd2 commit f1f02a2

File tree

5 files changed

+60
-3
lines changed

5 files changed

+60
-3
lines changed

pystac/catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ def from_dict(
904904
if migrate:
905905
result = pystac.read_dict(d, href=href, root=root)
906906
if not isinstance(result, Catalog):
907-
raise pystac.STACError(f"{result} is not a Catalog")
907+
raise pystac.STACTypeError(f"{result} is not a Catalog")
908908
return result
909909

910910
catalog_type = CatalogType.determine_type(d)
@@ -919,7 +919,7 @@ def from_dict(
919919

920920
d.pop("stac_version")
921921

922-
cat = Catalog(
922+
cat = cls(
923923
id=id,
924924
description=description,
925925
title=title,

pystac/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def from_dict(
610610

611611
d.pop("stac_version")
612612

613-
collection = Collection(
613+
collection = cls(
614614
id=id,
615615
description=description,
616616
extent=extent,

tests/test_catalog.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,3 +1107,22 @@ def test_full_copy_4(self) -> None:
11071107
].get_absolute_href()
11081108
assert href is not None
11091109
self.assertTrue(os.path.exists(href))
1110+
1111+
1112+
class CatalogSubClassTest(unittest.TestCase):
1113+
"""This tests cases related to creating classes inheriting from pystac.Catalog to
1114+
ensure that inheritance, class methods, etc. function as expected."""
1115+
1116+
TEST_CASE_1 = TestCases.get_path("data-files/catalogs/test-case-1/catalog.json")
1117+
1118+
class BasicCustomCatalog(pystac.Catalog):
1119+
pass
1120+
1121+
def setUp(self) -> None:
1122+
self.stac_io = pystac.StacIO.default()
1123+
1124+
def test_from_dict_returns_subclass(self) -> None:
1125+
1126+
catalog_dict = self.stac_io.read_json(self.TEST_CASE_1)
1127+
custom_catalog = self.BasicCustomCatalog.from_dict(catalog_dict)
1128+
self.assertIsInstance(custom_catalog, self.BasicCustomCatalog)

tests/test_collection.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,22 @@ def test_from_items(self) -> None:
264264

265265
self.assertEqual(interval[0], datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))
266266
self.assertEqual(interval[1], datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))
267+
268+
269+
class CollectionSubClassTest(unittest.TestCase):
270+
"""This tests cases related to creating classes inheriting from pystac.Catalog to
271+
ensure that inheritance, class methods, etc. function as expected."""
272+
273+
MULTI_EXTENT = TestCases.get_path("data-files/collections/multi-extent.json")
274+
275+
class BasicCustomCollection(pystac.Collection):
276+
pass
277+
278+
def setUp(self) -> None:
279+
self.stac_io = pystac.StacIO.default()
280+
281+
def test_from_dict_returns_subclass(self) -> None:
282+
283+
collection_dict = self.stac_io.read_json(self.MULTI_EXTENT)
284+
custom_collection = self.BasicCustomCollection.from_dict(collection_dict)
285+
self.assertIsInstance(custom_collection, self.BasicCustomCollection)

tests/test_item.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,3 +698,22 @@ def test_asset_updated(self) -> None:
698698
new_a1_value = cm.get_updated(item.assets["analytic"])
699699
self.assertEqual(new_a1_value, set_value)
700700
self.assertEqual(cm.updated, item_value)
701+
702+
703+
class ItemSubClassTest(unittest.TestCase):
704+
"""This tests cases related to creating classes inheriting from pystac.Catalog to
705+
ensure that inheritance, class methods, etc. function as expected."""
706+
707+
SAMPLE_ITEM = TestCases.get_path("data-files/item/sample-item.json")
708+
709+
class BasicCustomItem(pystac.Item):
710+
pass
711+
712+
def setUp(self) -> None:
713+
self.stac_io = pystac.StacIO.default()
714+
715+
def test_from_dict_returns_subclass(self) -> None:
716+
717+
item_dict = self.stac_io.read_json(self.SAMPLE_ITEM)
718+
custom_item = self.BasicCustomItem.from_dict(item_dict)
719+
self.assertIsInstance(custom_item, self.BasicCustomItem)

0 commit comments

Comments
 (0)