|
5 | 5 | from typing import List, Union
|
6 | 6 |
|
7 | 7 | import pystac
|
8 |
| -from pystac import Catalog, Item, CatalogType |
| 8 | +from pystac import Catalog, Collection, Item, CatalogType |
9 | 9 | from pystac.extensions.label import (
|
10 | 10 | LabelExtension,
|
11 | 11 | LabelClasses,
|
12 | 12 | LabelCount,
|
13 | 13 | LabelOverview,
|
14 | 14 | LabelStatistics,
|
| 15 | + LabelTask, |
15 | 16 | LabelType,
|
16 | 17 | LabelRelType,
|
17 | 18 | )
|
@@ -439,3 +440,85 @@ def test_ext_add_to(self) -> None:
|
439 | 440 | _ = LabelExtension.ext(item, add_if_missing=True)
|
440 | 441 |
|
441 | 442 | self.assertIn(LabelExtension.get_schema_uri(), item.stac_extensions)
|
| 443 | + |
| 444 | + |
| 445 | +class LabelSummariesTest(unittest.TestCase): |
| 446 | + EXAMPLE_COLLECTION = TestCases.get_path( |
| 447 | + "data-files/label/spacenet-roads/roads_collection.json" |
| 448 | + ) |
| 449 | + |
| 450 | + def test_label_properties_summary(self) -> None: |
| 451 | + label_properties = ["road_type", "lane_number", "paved"] |
| 452 | + collection = Collection.from_file(self.EXAMPLE_COLLECTION) |
| 453 | + label_ext_summaries = LabelExtension.summaries(collection) |
| 454 | + |
| 455 | + label_ext_summaries.label_properties = label_properties |
| 456 | + |
| 457 | + summaries = collection.summaries |
| 458 | + assert summaries is not None |
| 459 | + label_properties_summary = summaries.get_list("label:properties") |
| 460 | + assert label_properties_summary is not None |
| 461 | + self.assertListEqual(label_properties, label_properties_summary) |
| 462 | + |
| 463 | + label_properties_summary_ext = label_ext_summaries.label_properties |
| 464 | + assert label_properties_summary_ext is not None |
| 465 | + self.assertListEqual(label_properties, label_properties_summary_ext) |
| 466 | + |
| 467 | + def test_label_classes_summary(self) -> None: |
| 468 | + label_classes = [ |
| 469 | + LabelClasses( |
| 470 | + {"name": "road_type", "classes": ["1", "2", "3", "4", "5", "6"]} |
| 471 | + ), |
| 472 | + LabelClasses({"name": "lane_number", "classes": ["1", "2", "3", "4", "5"]}), |
| 473 | + LabelClasses({"name": "paved", "classes": ["0", "1"]}), |
| 474 | + ] |
| 475 | + collection = Collection.from_file(self.EXAMPLE_COLLECTION) |
| 476 | + label_ext_summaries = LabelExtension.summaries(collection) |
| 477 | + |
| 478 | + label_ext_summaries.label_classes = label_classes |
| 479 | + |
| 480 | + summaries = collection.summaries |
| 481 | + assert summaries is not None |
| 482 | + label_classes_summary = summaries.get_list("label:classes") |
| 483 | + assert label_classes_summary is not None |
| 484 | + self.assertListEqual( |
| 485 | + [lc.to_dict() for lc in label_classes], label_classes_summary |
| 486 | + ) |
| 487 | + |
| 488 | + label_classes_summary_ext = label_ext_summaries.label_classes |
| 489 | + assert label_classes_summary_ext is not None |
| 490 | + self.assertListEqual(label_classes, label_classes_summary_ext) |
| 491 | + |
| 492 | + def test_label_type_summary(self) -> None: |
| 493 | + label_types = [LabelType.VECTOR] |
| 494 | + collection = Collection.from_file(self.EXAMPLE_COLLECTION) |
| 495 | + label_ext_summaries = LabelExtension.summaries(collection) |
| 496 | + |
| 497 | + label_ext_summaries.label_type = label_types |
| 498 | + |
| 499 | + summaries = collection.summaries |
| 500 | + assert summaries is not None |
| 501 | + label_type_summary = summaries.get_list("label:type") |
| 502 | + assert label_type_summary is not None |
| 503 | + self.assertListEqual(label_types, label_type_summary) |
| 504 | + |
| 505 | + label_type_summary_ext = label_ext_summaries.label_type |
| 506 | + assert label_type_summary_ext is not None |
| 507 | + self.assertListEqual(label_types, label_type_summary_ext) |
| 508 | + |
| 509 | + def test_label_task_summary(self) -> None: |
| 510 | + label_tasks: List[Union[LabelTask, str]] = [LabelTask.REGRESSION] |
| 511 | + collection = Collection.from_file(self.EXAMPLE_COLLECTION) |
| 512 | + label_ext_summaries = LabelExtension.summaries(collection) |
| 513 | + |
| 514 | + label_ext_summaries.label_tasks = label_tasks |
| 515 | + |
| 516 | + summaries = collection.summaries |
| 517 | + assert summaries is not None |
| 518 | + label_tasks_summary = summaries.get_list("label:tasks") |
| 519 | + assert label_tasks_summary is not None |
| 520 | + self.assertListEqual(label_tasks, label_tasks_summary) |
| 521 | + |
| 522 | + label_tasks_summary_ext = label_ext_summaries.label_tasks |
| 523 | + assert label_tasks_summary_ext is not None |
| 524 | + self.assertListEqual(label_tasks, label_tasks_summary_ext) |
0 commit comments