Skip to content

Commit 3ae8d86

Browse files
add from_extensions class method to create CollectionSearch extensions classes (#745)
* add from_extensions class method to create CollectionSearch extensions classes * Apply suggestions from code review * Apply suggestions from code review * fix * update makefile
1 parent cf55d66 commit 3ae8d86

File tree

3 files changed

+223
-2
lines changed

3 files changed

+223
-2
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Added
6+
7+
* Add `from_extensions()` method to `CollectionSearchExtension` and `CollectionSearchPostExtension` extensions to build the class based on a list of available extensions.
8+
59
## [3.0.1] - 2024-08-27
610

711
### Changed

stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Collection-Search extension."""
22

3+
import warnings
34
from enum import Enum
45
from typing import List, Optional, Union
56

@@ -8,7 +9,7 @@
89
from stac_pydantic.api.collections import Collections
910
from stac_pydantic.shared import MimeTypes
1011

11-
from stac_fastapi.api.models import GeoJSONResponse
12+
from stac_fastapi.api.models import GeoJSONResponse, create_request_model
1213
from stac_fastapi.api.routes import create_async_endpoint
1314
from stac_fastapi.types.config import ApiSettings
1415
from stac_fastapi.types.extension import ApiExtension
@@ -71,6 +72,48 @@ def register(self, app: FastAPI) -> None:
7172
"""
7273
pass
7374

75+
@classmethod
76+
def from_extensions(
77+
cls,
78+
extensions: List[ApiExtension],
79+
schema_href: Optional[str] = None,
80+
) -> "CollectionSearchExtension":
81+
"""Create CollectionSearchExtension object from extensions."""
82+
supported_extensions = {
83+
"FreeTextExtension": ConformanceClasses.FREETEXT,
84+
"FreeTextAdvancedExtension": ConformanceClasses.FREETEXT,
85+
"QueryExtension": ConformanceClasses.QUERY,
86+
"SortExtension": ConformanceClasses.SORT,
87+
"FieldsExtension": ConformanceClasses.FIELDS,
88+
"FilterExtension": ConformanceClasses.FILTER,
89+
}
90+
conformance_classes = [
91+
ConformanceClasses.COLLECTIONSEARCH,
92+
ConformanceClasses.BASIS,
93+
]
94+
for ext in extensions:
95+
conf = supported_extensions.get(ext.__class__.__name__, None)
96+
if not conf:
97+
warnings.warn(
98+
f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501
99+
UserWarning,
100+
)
101+
else:
102+
conformance_classes.append(conf)
103+
104+
get_request_model = create_request_model(
105+
model_name="CollectionsGetRequest",
106+
base_model=BaseCollectionSearchGetRequest,
107+
extensions=extensions,
108+
request_type="GET",
109+
)
110+
111+
return cls(
112+
GET=get_request_model,
113+
conformance_classes=conformance_classes,
114+
schema_href=schema_href,
115+
)
116+
74117

75118
@attr.s
76119
class CollectionSearchPostExtension(CollectionSearchExtension):
@@ -132,3 +175,60 @@ def register(self, app: FastAPI) -> None:
132175
endpoint=create_async_endpoint(self.client.post_all_collections, self.POST),
133176
)
134177
app.include_router(self.router)
178+
179+
@classmethod
180+
def from_extensions(
181+
cls,
182+
extensions: List[ApiExtension],
183+
*,
184+
client: Union[AsyncBaseCollectionSearchClient, BaseCollectionSearchClient],
185+
settings: ApiSettings,
186+
schema_href: Optional[str] = None,
187+
router: Optional[APIRouter] = None,
188+
) -> "CollectionSearchPostExtension":
189+
"""Create CollectionSearchPostExtension object from extensions."""
190+
supported_extensions = {
191+
"FreeTextExtension": ConformanceClasses.FREETEXT,
192+
"FreeTextAdvancedExtension": ConformanceClasses.FREETEXT,
193+
"QueryExtension": ConformanceClasses.QUERY,
194+
"SortExtension": ConformanceClasses.SORT,
195+
"FieldsExtension": ConformanceClasses.FIELDS,
196+
"FilterExtension": ConformanceClasses.FILTER,
197+
}
198+
conformance_classes = [
199+
ConformanceClasses.COLLECTIONSEARCH,
200+
ConformanceClasses.BASIS,
201+
]
202+
for ext in extensions:
203+
conf = supported_extensions.get(ext.__class__.__name__, None)
204+
if not conf:
205+
warnings.warn(
206+
f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501
207+
UserWarning,
208+
)
209+
else:
210+
conformance_classes.append(conf)
211+
212+
get_request_model = create_request_model(
213+
model_name="CollectionsGetRequest",
214+
base_model=BaseCollectionSearchGetRequest,
215+
extensions=extensions,
216+
request_type="GET",
217+
)
218+
219+
post_request_model = create_request_model(
220+
model_name="CollectionsPostRequest",
221+
base_model=BaseCollectionSearchPostRequest,
222+
extensions=extensions,
223+
request_type="POST",
224+
)
225+
226+
return cls(
227+
client=client,
228+
settings=settings,
229+
GET=get_request_model,
230+
POST=post_request_model,
231+
conformance_classes=conformance_classes,
232+
router=router or APIRouter(),
233+
schema_href=schema_href,
234+
)

stac_fastapi/extensions/tests/test_collection_search.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@
22
from urllib.parse import quote_plus
33

44
import attr
5+
import pytest
56
from starlette.testclient import TestClient
67

78
from stac_fastapi.api.app import StacApi
89
from stac_fastapi.api.models import create_request_model
910
from stac_fastapi.extensions.core import (
11+
AggregationExtension,
1012
CollectionSearchExtension,
1113
CollectionSearchPostExtension,
14+
FieldsExtension,
15+
FilterExtension,
16+
FreeTextAdvancedExtension,
17+
FreeTextExtension,
18+
QueryExtension,
19+
SortExtension,
1220
)
1321
from stac_fastapi.extensions.core.collection_search import ConformanceClasses
1422
from stac_fastapi.extensions.core.collection_search.client import (
@@ -302,8 +310,8 @@ def test_collection_search_extension_post_models():
302310
client=DummyCoreClient(),
303311
extensions=[
304312
CollectionSearchPostExtension(
305-
settings=settings,
306313
client=DummyPostClient(),
314+
settings=settings,
307315
GET=get_request_model,
308316
POST=post_request_model,
309317
conformance_classes=[
@@ -392,3 +400,112 @@ def test_collection_search_extension_post_models():
392400
assert response_dict["query"]
393401
assert response_dict["sortby"]
394402
assert response_dict["fields"]
403+
404+
405+
@pytest.mark.parametrize(
406+
"extensions",
407+
[
408+
# with FreeTextExtension
409+
[
410+
FieldsExtension(),
411+
FilterExtension(),
412+
FreeTextExtension(),
413+
QueryExtension(),
414+
SortExtension(),
415+
],
416+
# with FreeTextAdvancedExtension
417+
[
418+
FieldsExtension(),
419+
FilterExtension(),
420+
FreeTextAdvancedExtension(),
421+
QueryExtension(),
422+
SortExtension(),
423+
],
424+
],
425+
)
426+
def test_from_extensions_methods(extensions):
427+
"""
428+
Make sure `from_extensions` create the correct
429+
models and adds desired conformances classes.
430+
"""
431+
ext = CollectionSearchExtension.from_extensions(
432+
extensions,
433+
)
434+
collection_search = ext.GET()
435+
assert collection_search.__class__.__name__ == "CollectionsGetRequest"
436+
assert hasattr(collection_search, "bbox")
437+
assert hasattr(collection_search, "datetime")
438+
assert hasattr(collection_search, "limit")
439+
assert hasattr(collection_search, "fields")
440+
assert hasattr(collection_search, "q")
441+
assert hasattr(collection_search, "sortby")
442+
assert hasattr(collection_search, "filter")
443+
assert ext.conformance_classes == [
444+
ConformanceClasses.COLLECTIONSEARCH,
445+
ConformanceClasses.BASIS,
446+
ConformanceClasses.FIELDS,
447+
ConformanceClasses.FILTER,
448+
ConformanceClasses.FREETEXT,
449+
ConformanceClasses.QUERY,
450+
ConformanceClasses.SORT,
451+
]
452+
453+
ext = CollectionSearchPostExtension.from_extensions(
454+
extensions,
455+
client=DummyPostClient(),
456+
settings=ApiSettings(),
457+
)
458+
collection_search = ext.POST()
459+
assert collection_search.__class__.__name__ == "CollectionsPostRequest"
460+
assert hasattr(collection_search, "bbox")
461+
assert hasattr(collection_search, "datetime")
462+
assert hasattr(collection_search, "limit")
463+
assert hasattr(collection_search, "fields")
464+
assert hasattr(collection_search, "q")
465+
assert hasattr(collection_search, "sortby")
466+
assert hasattr(collection_search, "filter")
467+
assert ext.conformance_classes == [
468+
ConformanceClasses.COLLECTIONSEARCH,
469+
ConformanceClasses.BASIS,
470+
ConformanceClasses.FIELDS,
471+
ConformanceClasses.FILTER,
472+
ConformanceClasses.FREETEXT,
473+
ConformanceClasses.QUERY,
474+
ConformanceClasses.SORT,
475+
]
476+
477+
478+
def test_from_extensions_methods_invalid():
479+
"""Should raise warnings for invalid extensions."""
480+
extensions = [
481+
AggregationExtension(),
482+
]
483+
with pytest.warns((UserWarning)):
484+
ext = CollectionSearchExtension.from_extensions(
485+
extensions,
486+
)
487+
collection_search = ext.GET()
488+
assert collection_search.__class__.__name__ == "CollectionsGetRequest"
489+
assert hasattr(collection_search, "bbox")
490+
assert hasattr(collection_search, "datetime")
491+
assert hasattr(collection_search, "limit")
492+
assert ext.conformance_classes == [
493+
ConformanceClasses.COLLECTIONSEARCH,
494+
ConformanceClasses.BASIS,
495+
]
496+
497+
with pytest.warns((UserWarning)):
498+
ext = CollectionSearchPostExtension.from_extensions(
499+
extensions,
500+
client=DummyPostClient(),
501+
settings=ApiSettings(),
502+
)
503+
collection_search = ext.POST()
504+
assert collection_search.__class__.__name__ == "CollectionsPostRequest"
505+
assert hasattr(collection_search, "bbox")
506+
assert hasattr(collection_search, "datetime")
507+
assert hasattr(collection_search, "limit")
508+
assert ext.conformance_classes == [
509+
ConformanceClasses.COLLECTIONSEARCH,
510+
ConformanceClasses.BASIS,
511+
]

0 commit comments

Comments
 (0)