Skip to content

Commit b173ede

Browse files
committed
Handle extra fields in extent classes
1 parent 5a42904 commit b173ede

File tree

2 files changed

+149
-37
lines changed

2 files changed

+149
-37
lines changed

pystac/collection.py

Lines changed: 103 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,41 @@ class SpatialExtent:
4646
array must be 2*n where n is the number of dimensions. For example, a
4747
2D Collection with only one bbox would be [[xmin, ymin, xmax, ymax]]
4848
49-
Attributes:
50-
bboxes : A list of bboxes that represent the spatial
51-
extent of the collection. Each bbox can be 2D or 3D. The length of the bbox
52-
array must be 2*n where n is the number of dimensions. For example, a
53-
2D Collection with only one bbox would be [[xmin, ymin, xmax, ymax]]
49+
extra_fields : Dictionary containing additional top-level fields defined on the
50+
Spatial Extent object.
5451
"""
5552

56-
def __init__(self, bboxes: Union[List[List[float]], List[float]]) -> None:
53+
bboxes: List[List[float]]
54+
"""A list of bboxes that represent the spatial
55+
extent of the collection. Each bbox can be 2D or 3D. The length of the bbox
56+
array must be 2*n where n is the number of dimensions. For example, a
57+
2D Collection with only one bbox would be [[xmin, ymin, xmax, ymax]]"""
58+
59+
extra_fields: Dict[str, Any]
60+
"""Dictionary containing additional top-level fields defined on the Spatial
61+
Extent object."""
62+
63+
def __init__(
64+
self,
65+
bboxes: Union[List[List[float]], List[float]],
66+
extra_fields: Optional[Dict[str, Any]] = None,
67+
) -> None:
5768
# A common mistake is to pass in a single bbox instead of a list of bboxes.
5869
# Account for this by transforming the input in that case.
5970
if isinstance(bboxes, list) and isinstance(bboxes[0], float):
6071
self.bboxes: List[List[float]] = [cast(List[float], bboxes)]
6172
else:
6273
self.bboxes = cast(List[List[float]], bboxes)
6374

75+
self.extra_fields = extra_fields or {}
76+
6477
def to_dict(self) -> Dict[str, Any]:
6578
"""Generate a dictionary representing the JSON of this SpatialExtent.
6679
6780
Returns:
6881
dict: A serialization of the SpatialExtent that can be written out as JSON.
6982
"""
70-
d = {"bbox": self.bboxes}
83+
d = {"bbox": self.bboxes, **self.extra_fields}
7184
return d
7285

7386
def clone(self) -> "SpatialExtent":
@@ -76,26 +89,34 @@ def clone(self) -> "SpatialExtent":
7689
Returns:
7790
SpatialExtent: The clone of this object.
7891
"""
79-
return SpatialExtent(deepcopy(self.bboxes))
92+
return SpatialExtent(
93+
bboxes=deepcopy(self.bboxes), extra_fields=deepcopy(self.extra_fields)
94+
)
8095

8196
@staticmethod
8297
def from_dict(d: Dict[str, Any]) -> "SpatialExtent":
83-
"""Constructs an SpatialExtent from a dict.
98+
"""Constructs a SpatialExtent from a dict.
8499
85100
Returns:
86101
SpatialExtent: The SpatialExtent deserialized from the JSON dict.
87102
"""
88-
return SpatialExtent(bboxes=d["bbox"])
103+
return SpatialExtent(
104+
bboxes=d["bbox"], extra_fields={k: v for k, v in d.items() if k != "bbox"}
105+
)
89106

90107
@staticmethod
91-
def from_coordinates(coordinates: List[Any]) -> "SpatialExtent":
108+
def from_coordinates(
109+
coordinates: List[Any], extra_fields: Optional[Dict[str, Any]] = None
110+
) -> "SpatialExtent":
92111
"""Constructs a SpatialExtent from a set of coordinates.
93112
94113
This method will only produce a single bbox that covers all points
95114
in the coordinate set.
96115
97116
Args:
98117
coordinates : Coordinates to derive the bbox from.
118+
extra_fields : Dictionary containing additional top-level fields defined on
119+
the Spatial Extent object.
99120
100121
Returns:
101122
SpatialExtent: A SpatialExtent with a single bbox that covers the
@@ -133,31 +154,40 @@ def process_coords(
133154
f"Could not determine bounds from coordinate sequence {coordinates}"
134155
)
135156

136-
return SpatialExtent([[xmin, ymin, xmax, ymax]])
157+
return SpatialExtent(
158+
bboxes=[[xmin, ymin, xmax, ymax]], extra_fields=extra_fields
159+
)
137160

138161

139162
class TemporalExtent:
140163
"""Describes the temporal extent of a Collection.
141164
142165
Args:
143166
intervals : A list of two datetimes wrapped in a list,
144-
representing the temporal extent of a Collection. Open date ranges are supported
145-
by setting either the start (the first element of the interval) or the end (the
146-
second element of the interval) to None.
147-
148-
149-
Attributes:
150-
intervals : A list of two datetimes wrapped in a list,
151-
representing the temporal extent of a Collection. Open date ranges are
152-
represented by either the start (the first element of the interval) or the
153-
end (the second element of the interval) being None.
167+
representing the temporal extent of a Collection. Open date ranges are
168+
supported by setting either the start (the first element of the interval)
169+
or the end (the second element of the interval) to None.
154170
171+
extra_fields : Dictionary containing additional top-level fields defined on the
172+
Temporal Extent object.
155173
Note:
156174
Datetimes are required to be in UTC.
157175
"""
158176

177+
intervals: List[List[Optional[Datetime]]]
178+
"""A list of two datetimes wrapped in a list,
179+
representing the temporal extent of a Collection. Open date ranges are
180+
represented by either the start (the first element of the interval) or the
181+
end (the second element of the interval) being None."""
182+
183+
extra_fields: Dict[str, Any]
184+
"""Dictionary containing additional top-level fields defined on the Temporal
185+
Extent object."""
186+
159187
def __init__(
160-
self, intervals: Union[List[List[Optional[Datetime]]], List[Optional[Datetime]]]
188+
self,
189+
intervals: Union[List[List[Optional[Datetime]]], List[Optional[Datetime]]],
190+
extra_fields: Optional[Dict[str, Any]] = None,
161191
):
162192
# A common mistake is to pass in a single interval instead of a
163193
# list of intervals. Account for this by transforming the input
@@ -167,6 +197,8 @@ def __init__(
167197
else:
168198
self.intervals = cast(List[List[Optional[Datetime]]], intervals)
169199

200+
self.extra_fields = extra_fields or {}
201+
170202
def to_dict(self) -> Dict[str, Any]:
171203
"""Generate a dictionary representing the JSON of this TemporalExtent.
172204
@@ -186,7 +218,7 @@ def to_dict(self) -> Dict[str, Any]:
186218

187219
encoded_intervals.append([start, end])
188220

189-
d = {"interval": encoded_intervals}
221+
d = {"interval": encoded_intervals, **self.extra_fields}
190222
return d
191223

192224
def clone(self) -> "TemporalExtent":
@@ -195,7 +227,9 @@ def clone(self) -> "TemporalExtent":
195227
Returns:
196228
TemporalExtent: The clone of this object.
197229
"""
198-
return TemporalExtent(intervals=deepcopy(self.intervals))
230+
return TemporalExtent(
231+
intervals=deepcopy(self.intervals), extra_fields=deepcopy(self.extra_fields)
232+
)
199233

200234
@staticmethod
201235
def from_dict(d: Dict[str, Any]) -> "TemporalExtent":
@@ -215,7 +249,10 @@ def from_dict(d: Dict[str, Any]) -> "TemporalExtent":
215249
end = dateutil.parser.parse(i[1])
216250
parsed_intervals.append([start, end])
217251

218-
return TemporalExtent(intervals=parsed_intervals)
252+
return TemporalExtent(
253+
intervals=parsed_intervals,
254+
extra_fields={k: v for k, v in d.items() if k != "interval"},
255+
)
219256

220257
@staticmethod
221258
def from_now() -> "TemporalExtent":
@@ -236,23 +273,41 @@ class Extent:
236273
Args:
237274
spatial : Potential spatial extent covered by the collection.
238275
temporal : Potential temporal extent covered by the collection.
239-
240-
Attributes:
241-
spatial : Potential spatial extent covered by the collection.
242-
temporal : Potential temporal extent covered by the collection.
276+
extra_fields : Dictionary containing additional top-level fields defined on the
277+
Extent object.
243278
"""
244279

245-
def __init__(self, spatial: SpatialExtent, temporal: TemporalExtent):
280+
spatial: SpatialExtent
281+
"""Potential spatial extent covered by the collection."""
282+
283+
temporal: TemporalExtent
284+
"""Potential temporal extent covered by the collection."""
285+
286+
extra_fields: Dict[str, Any]
287+
"""Dictionary containing additional top-level fields defined on the Extent
288+
object."""
289+
290+
def __init__(
291+
self,
292+
spatial: SpatialExtent,
293+
temporal: TemporalExtent,
294+
extra_fields: Optional[Dict[str, Any]] = None,
295+
):
246296
self.spatial = spatial
247297
self.temporal = temporal
298+
self.extra_fields = extra_fields or {}
248299

249300
def to_dict(self) -> Dict[str, Any]:
250301
"""Generate a dictionary representing the JSON of this Extent.
251302
252303
Returns:
253304
dict: A serialization of the Extent that can be written out as JSON.
254305
"""
255-
d = {"spatial": self.spatial.to_dict(), "temporal": self.temporal.to_dict()}
306+
d = {
307+
"spatial": self.spatial.to_dict(),
308+
"temporal": self.temporal.to_dict(),
309+
**self.extra_fields,
310+
}
256311

257312
return d
258313

@@ -262,7 +317,11 @@ def clone(self) -> "Extent":
262317
Returns:
263318
Extent: The clone of this extent.
264319
"""
265-
return Extent(spatial=copy(self.spatial), temporal=copy(self.temporal))
320+
return Extent(
321+
spatial=copy(self.spatial),
322+
temporal=copy(self.temporal),
323+
extra_fields=deepcopy(self.extra_fields),
324+
)
266325

267326
@staticmethod
268327
def from_dict(d: Dict[str, Any]) -> "Extent":
@@ -287,16 +346,23 @@ def from_dict(d: Dict[str, Any]) -> "Extent":
287346
temporal_extent_dict = temporal_extent
288347

289348
return Extent(
290-
SpatialExtent.from_dict(spatial_extent_dict),
291-
TemporalExtent.from_dict(temporal_extent_dict),
349+
spatial=SpatialExtent.from_dict(spatial_extent_dict),
350+
temporal=TemporalExtent.from_dict(temporal_extent_dict),
351+
extra_fields={
352+
k: v for k, v in d.items() if k not in {"spatial", "temporal"}
353+
},
292354
)
293355

294356
@staticmethod
295-
def from_items(items: Iterable["Item_Type"]) -> "Extent":
357+
def from_items(
358+
items: Iterable["Item_Type"], extra_fields: Optional[Dict[str, Any]] = None
359+
) -> "Extent":
296360
"""Create an Extent based on the datetimes and bboxes of a list of items.
297361
298362
Args:
299363
items : A list of items to derive the extent from.
364+
extra_fields : Optional dictionary containing additional top-level fields
365+
defined on the Extent object.
300366
301367
Returns:
302368
Extent: An Extent that spatially and temporally covers all of the
@@ -354,7 +420,7 @@ def from_items(items: Iterable["Item_Type"]) -> "Extent":
354420
)
355421
temporal = TemporalExtent([[start_timestamp, end_timestamp]])
356422

357-
return Extent(spatial, temporal)
423+
return Extent(spatial=spatial, temporal=temporal, extra_fields=extra_fields)
358424

359425

360426
class Provider:

tests/test_collection.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def test_from_invalid_dict_raises_exception(self) -> None:
252252

253253

254254
class ExtentTest(unittest.TestCase):
255+
def setUp(self) -> None:
256+
self.maxDiff = None
257+
255258
def test_spatial_allows_single_bbox(self) -> None:
256259
temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]])
257260

@@ -319,6 +322,49 @@ def test_from_items(self) -> None:
319322
self.assertEqual(interval[0], datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))
320323
self.assertEqual(interval[1], datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC))
321324

325+
def test_to_from_dict(self) -> None:
326+
spatial_dict = {
327+
"bbox": [
328+
[
329+
172.91173669923782,
330+
1.3438851951615003,
331+
172.95469614953714,
332+
1.3690476620161975,
333+
]
334+
],
335+
"extension:field": "spatial value",
336+
}
337+
temporal_dict = {
338+
"interval": [
339+
["2020-12-11T22:38:32.125000Z", "2020-12-14T18:02:31.437000Z"]
340+
],
341+
"extension:field": "temporal value",
342+
}
343+
extent_dict = {
344+
"spatial": spatial_dict,
345+
"temporal": temporal_dict,
346+
"extension:field": "extent value",
347+
}
348+
expected_extent_extra_fields = {
349+
"extension:field": extent_dict["extension:field"],
350+
}
351+
expected_spatial_extra_fields = {
352+
"extension:field": spatial_dict["extension:field"],
353+
}
354+
expected_temporal_extra_fields = {
355+
"extension:field": temporal_dict["extension:field"],
356+
}
357+
358+
extent = Extent.from_dict(extent_dict)
359+
360+
self.assertDictEqual(expected_extent_extra_fields, extent.extra_fields)
361+
self.assertDictEqual(expected_spatial_extra_fields, extent.spatial.extra_fields)
362+
self.assertDictEqual(
363+
expected_temporal_extra_fields, extent.temporal.extra_fields
364+
)
365+
366+
self.assertDictEqual(extent_dict, extent.to_dict())
367+
322368

323369
class CollectionSubClassTest(unittest.TestCase):
324370
"""This tests cases related to creating classes inheriting from pystac.Catalog to

0 commit comments

Comments
 (0)