Skip to content

Commit dbb8db9

Browse files
committed
Add summaries for Label Extension
1 parent e371fb2 commit dbb8db9

File tree

3 files changed

+218
-3
lines changed

3 files changed

+218
-3
lines changed

pystac/extensions/label.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
"""
55

66
from enum import Enum
7-
from pystac.extensions.base import ExtensionManagementMixin
7+
from pystac.extensions.base import ExtensionManagementMixin, SummariesExtension
88
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
99

1010
import pystac
1111
from pystac.serialization.identify import STACJSONDescription, STACVersionID
1212
from pystac.extensions.hooks import ExtensionHooks
13-
from pystac.utils import get_required
13+
from pystac.utils import get_required, map_opt
1414

1515
SCHEMA_URI = "https://stac-extensions.github.io/label/v1.0.0/schema.json"
1616

@@ -703,6 +703,83 @@ def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "LabelExtension"
703703
f"Label extension does not apply to type {type(obj)}"
704704
)
705705

706+
@staticmethod
707+
def summaries(obj: pystac.Collection) -> "SummariesLabelExtension":
708+
"""Returns the extended summaries object for the given collection."""
709+
return SummariesLabelExtension(obj)
710+
711+
712+
class SummariesLabelExtension(SummariesExtension):
713+
"""A concrete implementation of :class:`~SummariesExtension` that extends
714+
the ``summaries`` field of a :class:`~pystac.Collection` to include properties
715+
defined in the :stac-ext:`Label Extension <label>`.
716+
"""
717+
718+
@property
719+
def label_properties(self) -> Optional[List[str]]:
720+
"""Get or sets the summary of :attr:`LabelExtension.label_properties` values
721+
for this Collection.
722+
"""
723+
724+
return self.summaries.get_list(PROPERTIES_PROP)
725+
726+
@label_properties.setter
727+
def label_properties(self, v: Optional[List[LabelClasses]]) -> None:
728+
self._set_summary(PROPERTIES_PROP, v)
729+
730+
@property
731+
def label_classes(self) -> Optional[List[LabelClasses]]:
732+
"""Get or sets the summary of :attr:`LabelExtension.label_classes` values
733+
for this Collection.
734+
"""
735+
736+
return map_opt(
737+
lambda classes: [LabelClasses(c) for c in classes],
738+
self.summaries.get_list(CLASSES_PROP),
739+
)
740+
741+
@label_classes.setter
742+
def label_classes(self, v: Optional[List[LabelClasses]]) -> None:
743+
self._set_summary(
744+
CLASSES_PROP, map_opt(lambda classes: [c.to_dict() for c in classes], v)
745+
)
746+
747+
@property
748+
def label_type(self) -> Optional[List[LabelType]]:
749+
"""Get or sets the summary of :attr:`LabelExtension.label_type` values
750+
for this Collection.
751+
"""
752+
753+
return self.summaries.get_list(TYPE_PROP)
754+
755+
@label_type.setter
756+
def label_type(self, v: Optional[List[LabelType]]) -> None:
757+
self._set_summary(TYPE_PROP, v)
758+
759+
@property
760+
def label_tasks(self) -> Optional[List[Union[LabelTask, str]]]:
761+
"""Get or sets the summary of :attr:`LabelExtension.label_tasks` values
762+
for this Collection.
763+
"""
764+
765+
return self.summaries.get_list(TASKS_PROP)
766+
767+
@label_tasks.setter
768+
def label_tasks(self, v: Optional[List[Union[LabelTask, str]]]) -> None:
769+
self._set_summary(TASKS_PROP, v)
770+
771+
@property
772+
def label_methods(self) -> Optional[List[str]]:
773+
"""Get or sets the summary of :attr:`LabelExtension.label_methods` values
774+
for this Collection.
775+
"""
776+
777+
return self.summaries.get_list(METHODS_PROP)
778+
779+
@label_methods.setter
780+
def label_methods(self, v: Optional[List[str]]) -> None:
781+
self._set_summary(METHODS_PROP, v)
782+
706783

707784
class LabelExtensionHooks(ExtensionHooks):
708785
schema_uri: str = SCHEMA_URI
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
{
2+
"stac_version": "1.0.0-rc.1",
3+
"type": "Collection",
4+
"id": "spacenet-roads-sample",
5+
"description": "A sample of the SpaceNet Roads dataset built during STAC Sprint 4. The dataset contains hand-labeled roads.",
6+
"keywords": [
7+
"spacenet",
8+
"roads",
9+
"labels"
10+
],
11+
"license": "CC-BY-SA-4.0",
12+
"providers": [
13+
{
14+
"name": "SpaceNet",
15+
"roles": [
16+
"licensor",
17+
"host",
18+
"producer",
19+
"processor"
20+
],
21+
"url": "https://spacenet.ai"
22+
}
23+
],
24+
"extent": {
25+
"spatial": {
26+
"bbox": [
27+
[
28+
2.23379639995,
29+
49.0178709,
30+
2.23730639995,
31+
49.0213809
32+
]
33+
]
34+
},
35+
"temporal": {
36+
"interval": [
37+
[
38+
"2016-08-26T22:41:55.000000Z",
39+
null
40+
]
41+
]
42+
}
43+
},
44+
"links": [
45+
{
46+
"href": "roads_collection.json",
47+
"rel": "root",
48+
"title": "sample SpaceNet roads label collection"
49+
},
50+
{
51+
"rel": "item",
52+
"href": "roads_item.json"
53+
}
54+
]
55+
}

tests/extensions/test_label.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from typing import List, Union
66

77
import pystac
8-
from pystac import Catalog, Item, CatalogType
8+
from pystac import Catalog, Collection, Item, CatalogType
99
from pystac.extensions.label import (
1010
LabelExtension,
1111
LabelClasses,
1212
LabelCount,
1313
LabelOverview,
1414
LabelStatistics,
15+
LabelTask,
1516
LabelType,
1617
LabelRelType,
1718
)
@@ -439,3 +440,85 @@ def test_ext_add_to(self) -> None:
439440
_ = LabelExtension.ext(item, add_if_missing=True)
440441

441442
self.assertIn(LabelExtension.get_schema_uri(), item.stac_extensions)
443+
444+
445+
class LabelSummariesTest(unittest.TestCase):
446+
EXAMPLE_COLLECTION = TestCases.get_path(
447+
"data-files/label/spacenet-roads/roads_collection.json"
448+
)
449+
450+
def test_label_properties_summary(self) -> None:
451+
label_properties = ["road_type", "lane_number", "paved"]
452+
collection = Collection.from_file(self.EXAMPLE_COLLECTION)
453+
label_ext_summaries = LabelExtension.summaries(collection)
454+
455+
label_ext_summaries.label_properties = label_properties
456+
457+
summaries = collection.summaries
458+
assert summaries is not None
459+
label_properties_summary = summaries.get_list("label:properties")
460+
assert label_properties_summary is not None
461+
self.assertListEqual(label_properties, label_properties_summary)
462+
463+
label_properties_summary_ext = label_ext_summaries.label_properties
464+
assert label_properties_summary_ext is not None
465+
self.assertListEqual(label_properties, label_properties_summary_ext)
466+
467+
def test_label_classes_summary(self) -> None:
468+
label_classes = [
469+
LabelClasses(
470+
{"name": "road_type", "classes": ["1", "2", "3", "4", "5", "6"]}
471+
),
472+
LabelClasses({"name": "lane_number", "classes": ["1", "2", "3", "4", "5"]}),
473+
LabelClasses({"name": "paved", "classes": ["0", "1"]}),
474+
]
475+
collection = Collection.from_file(self.EXAMPLE_COLLECTION)
476+
label_ext_summaries = LabelExtension.summaries(collection)
477+
478+
label_ext_summaries.label_classes = label_classes
479+
480+
summaries = collection.summaries
481+
assert summaries is not None
482+
label_classes_summary = summaries.get_list("label:classes")
483+
assert label_classes_summary is not None
484+
self.assertListEqual(
485+
[lc.to_dict() for lc in label_classes], label_classes_summary
486+
)
487+
488+
label_classes_summary_ext = label_ext_summaries.label_classes
489+
assert label_classes_summary_ext is not None
490+
self.assertListEqual(label_classes, label_classes_summary_ext)
491+
492+
def test_label_type_summary(self) -> None:
493+
label_types = [LabelType.VECTOR]
494+
collection = Collection.from_file(self.EXAMPLE_COLLECTION)
495+
label_ext_summaries = LabelExtension.summaries(collection)
496+
497+
label_ext_summaries.label_type = label_types
498+
499+
summaries = collection.summaries
500+
assert summaries is not None
501+
label_type_summary = summaries.get_list("label:type")
502+
assert label_type_summary is not None
503+
self.assertListEqual(label_types, label_type_summary)
504+
505+
label_type_summary_ext = label_ext_summaries.label_type
506+
assert label_type_summary_ext is not None
507+
self.assertListEqual(label_types, label_type_summary_ext)
508+
509+
def test_label_task_summary(self) -> None:
510+
label_tasks: List[Union[LabelTask, str]] = [LabelTask.REGRESSION]
511+
collection = Collection.from_file(self.EXAMPLE_COLLECTION)
512+
label_ext_summaries = LabelExtension.summaries(collection)
513+
514+
label_ext_summaries.label_tasks = label_tasks
515+
516+
summaries = collection.summaries
517+
assert summaries is not None
518+
label_tasks_summary = summaries.get_list("label:tasks")
519+
assert label_tasks_summary is not None
520+
self.assertListEqual(label_tasks, label_tasks_summary)
521+
522+
label_tasks_summary_ext = label_ext_summaries.label_tasks
523+
assert label_tasks_summary_ext is not None
524+
self.assertListEqual(label_tasks, label_tasks_summary_ext)

0 commit comments

Comments
 (0)