Skip to content

Commit 3aac28a

Browse files
committed
Move add_if_missing logic to validate_has_extension
1 parent 77f203e commit 3aac28a

File tree

16 files changed

+75
-117
lines changed

16 files changed

+75
-117
lines changed

pystac/extensions/base.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,26 @@ def has_extension(cls, obj: S) -> bool:
126126
)
127127

128128
@classmethod
129-
def validate_has_extension(cls, obj: Union[S, pystac.Asset]) -> None:
130-
"""Given a :class:`~pystac.STACObject` or :class:`pystac.Asset` instance, checks
131-
if the object (or its owner in the case of an Asset) has this extension's schema
132-
URI in it's :attr:`~pystac.STACObject.stac_extensions` list."""
133-
extensible = obj.owner if isinstance(obj, pystac.Asset) else obj
134-
if (
135-
extensible is not None
136-
and cls.get_schema_uri() not in extensible.stac_extensions
137-
):
129+
def validate_has_extension(
130+
cls, extensible: Optional[S], add_if_missing: bool
131+
) -> None:
132+
"""Given a :class:`~pystac.STACObject`, checks if the object has this
133+
extension's schema URI in it's :attr:`~pystac.STACObject.stac_extensions` list.
134+
If ``add_if_missing`` is ``True``, the schema URI will be added to the object.
135+
136+
Args:
137+
extensible : The object to validate.
138+
add_if_missing : Whether to add the schema URI to the object if the URI is
139+
not already present.
140+
141+
"""
142+
if add_if_missing:
143+
cls.add_to(extensible)
144+
145+
if extensible is None:
146+
return
147+
148+
if cls.get_schema_uri() not in extensible.stac_extensions:
138149
raise pystac.ExtensionNotImplemented(
139150
f"Could not find extension schema URI {cls.get_schema_uri()} in object."
140151
)

pystac/extensions/datacube.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,19 +340,13 @@ def get_schema_uri(cls) -> str:
340340
@classmethod
341341
def ext(cls, obj: T, add_if_missing: bool = False) -> "DatacubeExtension[T]":
342342
if isinstance(obj, pystac.Collection):
343-
if add_if_missing:
344-
cls.add_to(obj)
345-
cls.validate_has_extension(obj)
343+
cls.validate_has_extension(obj, add_if_missing)
346344
return cast(DatacubeExtension[T], CollectionDatacubeExtension(obj))
347345
if isinstance(obj, pystac.Item):
348-
if add_if_missing:
349-
cls.add_to(obj)
350-
cls.validate_has_extension(obj)
346+
cls.validate_has_extension(obj, add_if_missing)
351347
return cast(DatacubeExtension[T], ItemDatacubeExtension(obj))
352348
elif isinstance(obj, pystac.Asset):
353-
if add_if_missing and obj.owner is not None:
354-
cls.add_to(obj.owner)
355-
cls.validate_has_extension(obj)
349+
cls.validate_has_extension(obj.owner, add_if_missing)
356350
return cast(DatacubeExtension[T], AssetDatacubeExtension(obj))
357351
else:
358352
raise pystac.ExtensionTypeError(

pystac/extensions/eo.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,10 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> "EOExtension[T]":
361361
pystac.ExtensionTypeError : If an invalid object type is passed.
362362
"""
363363
if isinstance(obj, pystac.Item):
364-
if add_if_missing:
365-
cls.add_to(obj)
366-
cls.validate_has_extension(obj)
364+
cls.validate_has_extension(obj, add_if_missing)
367365
return cast(EOExtension[T], ItemEOExtension(obj))
368366
elif isinstance(obj, pystac.Asset):
369-
if add_if_missing and isinstance(obj.owner, pystac.Item):
370-
cls.add_to(obj.owner)
371-
cls.validate_has_extension(obj)
367+
cls.validate_has_extension(obj.owner, add_if_missing)
372368
return cast(EOExtension[T], AssetEOExtension(obj))
373369
else:
374370
raise pystac.ExtensionTypeError(
@@ -380,10 +376,7 @@ def summaries(
380376
cls, obj: pystac.Collection, add_if_missing: bool = False
381377
) -> "SummariesEOExtension":
382378
"""Returns the extended summaries object for the given collection."""
383-
if not add_if_missing:
384-
cls.validate_has_extension(obj)
385-
else:
386-
cls.add_to(obj)
379+
cls.validate_has_extension(obj, add_if_missing)
387380
return SummariesEOExtension(obj)
388381

389382

pystac/extensions/file.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from enum import Enum
7-
from typing import Any, Dict, List, Optional
7+
from typing import Any, Dict, List, Optional, Union
88

99
import pystac
1010
from pystac.extensions.base import ExtensionManagementMixin, PropertiesExtension
@@ -85,7 +85,9 @@ def summary(self, v: str) -> None:
8585
self.properties["summary"] = v
8686

8787

88-
class FileExtension(PropertiesExtension, ExtensionManagementMixin[pystac.Item]):
88+
class FileExtension(
89+
PropertiesExtension, ExtensionManagementMixin[Union[pystac.Item, pystac.Collection]]
90+
):
8991
"""A class that can be used to extend the properties of an :class:`~pystac.Asset`
9092
with properties from the :stac-ext:`File Info Extension <file>`.
9193
@@ -197,9 +199,7 @@ def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "FileExtension"
197199
This extension can be applied to instances of :class:`~pystac.Asset`.
198200
"""
199201
if isinstance(obj, pystac.Asset):
200-
if add_if_missing and isinstance(obj.owner, pystac.Item):
201-
cls.add_to(obj.owner)
202-
cls.validate_has_extension(obj)
202+
cls.validate_has_extension(obj.owner, add_if_missing)
203203
return cls(obj)
204204
else:
205205
raise pystac.ExtensionTypeError(

pystac/extensions/item_assets.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ def ext(
119119
cls, obj: pystac.Collection, add_if_missing: bool = False
120120
) -> "ItemAssetsExtension":
121121
if isinstance(obj, pystac.Collection):
122-
if add_if_missing:
123-
cls.add_to(obj)
122+
cls.validate_has_extension(obj, add_if_missing)
124123
return cls(obj)
125124
else:
126125
raise pystac.ExtensionTypeError(

pystac/extensions/label.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -691,9 +691,7 @@ def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "LabelExtension"
691691
This extension can be applied to instances of :class:`~pystac.Item`.
692692
"""
693693
if isinstance(obj, pystac.Item):
694-
if add_if_missing:
695-
cls.add_to(obj)
696-
cls.validate_has_extension(obj)
694+
cls.validate_has_extension(obj, add_if_missing)
697695
return cls(obj)
698696
else:
699697
raise pystac.ExtensionTypeError(
@@ -705,10 +703,7 @@ def summaries(
705703
cls, obj: pystac.Collection, add_if_missing: bool = False
706704
) -> "SummariesLabelExtension":
707705
"""Returns the extended summaries object for the given collection."""
708-
if not add_if_missing:
709-
cls.validate_has_extension(obj)
710-
else:
711-
cls.add_to(obj)
706+
cls.validate_has_extension(obj, add_if_missing)
712707
return SummariesLabelExtension(obj)
713708

714709

pystac/extensions/pointcloud.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -525,14 +525,15 @@ def get_schema_uri(cls) -> str:
525525
@classmethod
526526
def ext(cls, obj: T, add_if_missing: bool = False) -> "PointcloudExtension[T]":
527527
if isinstance(obj, pystac.Item):
528-
if add_if_missing:
529-
cls.add_to(obj)
530-
cls.validate_has_extension(obj)
528+
cls.validate_has_extension(obj, add_if_missing)
531529
return cast(PointcloudExtension[T], ItemPointcloudExtension(obj))
532530
elif isinstance(obj, pystac.Asset):
533-
if add_if_missing and isinstance(obj.owner, pystac.Item):
534-
cls.add_to(obj.owner)
535-
cls.validate_has_extension(obj)
531+
if obj.owner is not None and not isinstance(obj.owner, pystac.Item):
532+
raise pystac.ExtensionTypeError(
533+
"Pointcloud extension does not apply to Assets owned by anything "
534+
"other than an Item."
535+
)
536+
cls.validate_has_extension(obj.owner, add_if_missing)
536537
return cast(PointcloudExtension[T], AssetPointcloudExtension(obj))
537538
else:
538539
raise pystac.ExtensionTypeError(

pystac/extensions/projection.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,10 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> "ProjectionExtension[T]":
272272
pystac.ExtensionTypeError : If an invalid object type is passed.
273273
"""
274274
if isinstance(obj, pystac.Item):
275-
if add_if_missing:
276-
cls.add_to(obj)
277-
cls.validate_has_extension(obj)
275+
cls.validate_has_extension(obj, add_if_missing)
278276
return cast(ProjectionExtension[T], ItemProjectionExtension(obj))
279277
elif isinstance(obj, pystac.Asset):
280-
if add_if_missing and isinstance(obj.owner, pystac.Item):
281-
cls.add_to(obj.owner)
282-
cls.validate_has_extension(obj)
278+
cls.validate_has_extension(obj.owner, add_if_missing)
283279
return cast(ProjectionExtension[T], AssetProjectionExtension(obj))
284280
else:
285281
raise pystac.ExtensionTypeError(
@@ -291,10 +287,7 @@ def summaries(
291287
cls, obj: pystac.Collection, add_if_missing: bool = False
292288
) -> "SummariesProjectionExtension":
293289
"""Returns the extended summaries object for the given collection."""
294-
if not add_if_missing:
295-
cls.validate_has_extension(obj)
296-
else:
297-
cls.add_to(obj)
290+
cls.validate_has_extension(obj, add_if_missing)
298291
return SummariesProjectionExtension(obj)
299292

300293

pystac/extensions/raster.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,7 @@ def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "RasterExtensio
704704
pystac.ExtensionTypeError : If an invalid object type is passed.
705705
"""
706706
if isinstance(obj, pystac.Asset):
707-
if add_if_missing and isinstance(obj.owner, pystac.Item):
708-
cls.add_to(obj.owner)
709-
cls.validate_has_extension(obj)
707+
cls.validate_has_extension(obj.owner, add_if_missing)
710708
return cls(obj)
711709
else:
712710
raise pystac.ExtensionTypeError(
@@ -717,10 +715,7 @@ def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "RasterExtensio
717715
def summaries(
718716
cls, obj: pystac.Collection, add_if_missing: bool = False
719717
) -> "SummariesRasterExtension":
720-
if not add_if_missing:
721-
cls.validate_has_extension(obj)
722-
else:
723-
cls.add_to(obj)
718+
cls.validate_has_extension(obj, add_if_missing)
724719
return SummariesRasterExtension(obj)
725720

726721

pystac/extensions/sar.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -315,14 +315,15 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> "SarExtension[T]":
315315
pystac.ExtensionTypeError : If an invalid object type is passed.
316316
"""
317317
if isinstance(obj, pystac.Item):
318-
if add_if_missing:
319-
cls.add_to(obj)
320-
cls.validate_has_extension(obj)
318+
cls.validate_has_extension(obj, add_if_missing)
321319
return cast(SarExtension[T], ItemSarExtension(obj))
322320
elif isinstance(obj, pystac.Asset):
323-
if add_if_missing and isinstance(obj.owner, pystac.Item):
324-
cls.add_to(obj.owner)
325-
cls.validate_has_extension(obj)
321+
if obj.owner is not None and not isinstance(obj.owner, pystac.Item):
322+
raise pystac.ExtensionTypeError(
323+
"SAR extension does not apply to Assets owned by anything "
324+
"other than an Item."
325+
)
326+
cls.validate_has_extension(obj.owner, add_if_missing)
326327
return cast(SarExtension[T], AssetSarExtension(obj))
327328
else:
328329
raise pystac.ExtensionTypeError(

0 commit comments

Comments
 (0)