From 0f4cbb7072efcd031bebb293764aa3e519707ad8 Mon Sep 17 00:00:00 2001
From: rhysrevans3
Date: Wed, 22 May 2024 08:45:33 +0100
Subject: [PATCH 1/5] Moving database logic to db for collection endpoints.

---
 .../stac_fastapi/core/base_database_logic.py  |  19 +-
 stac_fastapi/core/stac_fastapi/core/core.py   | 328 ++----------------
 .../elasticsearch/database_logic.py           | 182 ++++++++--
 3 files changed, 196 insertions(+), 333 deletions(-)

diff --git a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
index 0043cfb8..d26d91eb 100644
--- a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
+++ b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
@@ -20,8 +20,8 @@ async def get_all_collections(
         pass
 
     @abc.abstractmethod
-    async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
-        """Retrieve a single item from the database."""
+    async def get_item(self, collection_id: str, item_id: str) -> Dict:
+        """Retrieve an item from the database."""
         pass
 
     @abc.abstractmethod
@@ -37,13 +37,20 @@ async def delete_item(
         pass
 
     @abc.abstractmethod
-    async def create_collection(self, collection: Dict, refresh: bool = False) -> None:
-        """Create a collection in the database."""
+    async def item_search(
+        self, item_id: str, collection_id: str, refresh: bool = False
+    ) -> None:
+        """Retrieve items that match the query from the database."""
        pass
 
     @abc.abstractmethod
-    async def find_collection(self, collection_id: str) -> Dict:
-        """Find a collection in the database."""
+    async def get_collection(self, collection_id: str) -> Dict:
+        """Retrieve a collection from the database."""
+        pass
+
+    @abc.abstractmethod
+    async def create_collection(self, collection: Dict, refresh: bool = False) -> None:
+        """Create a collection in the database."""
         pass
 
     @abc.abstractmethod
diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py
index 5469bf10..fd69a96c 100644
--- a/stac_fastapi/core/stac_fastapi/core/core.py
+++ b/stac_fastapi/core/stac_fastapi/core/core.py
@@ -1,4 +1,5 @@
 """Core client."""
+
 import logging
 import re
 from datetime import datetime as datetime_type
@@ -61,8 +62,6 @@ class CoreClient(AsyncBaseCoreClient):
         session (Session): A requests session instance to be used for all HTTP requests.
         item_serializer (Type[serializers.ItemSerializer]): A serializer class to be used to convert
             between STAC items and database records.
-        collection_serializer (Type[serializers.CollectionSerializer]): A serializer class to be
-            used to convert between STAC collections and database records.
         database (DatabaseLogic): An instance of the `DatabaseLogic` class that is used to interact
             with the database.
""" @@ -75,121 +74,12 @@ class CoreClient(AsyncBaseCoreClient): session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer) - collection_serializer: Type[CollectionSerializer] = attr.ib( - default=CollectionSerializer - ) post_request_model = attr.ib(default=BaseSearchPostRequest) stac_version: str = attr.ib(default=STAC_VERSION) landing_page_id: str = attr.ib(default="stac-fastapi") title: str = attr.ib(default="stac-fastapi") description: str = attr.ib(default="stac-fastapi") - def _landing_page( - self, - base_url: str, - conformance_classes: List[str], - extension_schemas: List[str], - ) -> stac_types.LandingPage: - landing_page = stac_types.LandingPage( - type="Catalog", - id=self.landing_page_id, - title=self.title, - description=self.description, - stac_version=self.stac_version, - conformsTo=conformance_classes, - links=[ - { - "rel": Relations.self.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.root.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": "data", - "type": MimeTypes.json, - "href": urljoin(base_url, "collections"), - }, - { - "rel": Relations.conformance.value, - "type": MimeTypes.json, - "title": "STAC/WFS3 conformance classes implemented by this server", - "href": urljoin(base_url, "conformance"), - }, - { - "rel": Relations.search.value, - "type": MimeTypes.geojson, - "title": "STAC search", - "href": urljoin(base_url, "search"), - "method": "GET", - }, - { - "rel": Relations.search.value, - "type": MimeTypes.geojson, - "title": "STAC search", - "href": urljoin(base_url, "search"), - "method": "POST", - }, - ], - stac_extensions=extension_schemas, - ) - return landing_page - - async def landing_page(self, **kwargs) -> stac_types.LandingPage: - """Landing page. - - Called with `GET /`. - - Returns: - API landing page, serving as an entry point to the API. - """ - request: Request = kwargs["request"] - base_url = get_base_url(request) - landing_page = self._landing_page( - base_url=base_url, - conformance_classes=self.conformance_classes(), - extension_schemas=[], - ) - collections = await self.all_collections(request=kwargs["request"]) - for collection in collections["collections"]: - landing_page["links"].append( - { - "rel": Relations.child.value, - "type": MimeTypes.json.value, - "title": collection.get("title") or collection.get("id"), - "href": urljoin(base_url, f"collections/{collection['id']}"), - } - ) - - # Add OpenAPI URL - landing_page["links"].append( - { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", - "title": "OpenAPI service description", - "href": urljoin( - str(request.base_url), request.app.openapi_url.lstrip("/") - ), - } - ) - - # Add human readable service-doc - landing_page["links"].append( - { - "rel": "service-doc", - "type": "text/html", - "title": "OpenAPI service documentation", - "href": urljoin( - str(request.base_url), request.app.docs_url.lstrip("/") - ), - } - ) - - return landing_page - async def all_collections(self, **kwargs) -> stac_types.Collections: """Read all collections from the database. @@ -199,30 +89,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: Returns: A Collections object containing all the collections in the database and links to various resources. 
""" - request = kwargs["request"] - base_url = str(request.base_url) - limit = int(request.query_params.get("limit", 10)) - token = request.query_params.get("token") - - collections, next_token = await self.database.get_all_collections( - token=token, limit=limit, base_url=base_url - ) - - links = [ - {"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url}, - {"rel": Relations.parent.value, "type": MimeTypes.json, "href": base_url}, - { - "rel": Relations.self.value, - "type": MimeTypes.json, - "href": urljoin(base_url, "collections"), - }, - ] - - if next_token: - next_link = PagingLinks(next=next_token, request=request).link_next() - links.append(next_link) - - return stac_types.Collections(collections=collections, links=links) + return await self.database.get_all_collections(request=kwargs["request"]) async def get_collection( self, collection_id: str, **kwargs @@ -239,10 +106,8 @@ async def get_collection( Raises: NotFoundError: If the collection with the given id cannot be found in the database. """ - base_url = str(kwargs["request"].base_url) - collection = await self.database.find_collection(collection_id=collection_id) - return self.collection_serializer.db_to_stac( - collection=collection, base_url=base_url + return await self.database.find_collection( + collection_id=collection_id, request=kwargs["request"] ) async def item_collection( @@ -272,54 +137,13 @@ async def item_collection( HTTPException: If the specified collection is not found. Exception: If any error occurs while reading the items from the database. """ - request: Request = kwargs["request"] - base_url = str(request.base_url) - - collection = await self.get_collection( - collection_id=collection_id, request=request - ) - collection_id = collection.get("id") - if collection_id is None: - raise HTTPException(status_code=404, detail="Collection not found") - - search = self.database.make_search() - search = self.database.apply_collections_filter( - search=search, collection_ids=[collection_id] - ) - - if datetime: - datetime_search = self._return_date(datetime) - search = self.database.apply_datetime_filter( - search=search, datetime_search=datetime_search - ) - - if bbox: - bbox = [float(x) for x in bbox] - if len(bbox) == 6: - bbox = [bbox[0], bbox[1], bbox[3], bbox[4]] - - search = self.database.apply_bbox_filter(search=search, bbox=bbox) - - items, maybe_count, next_token = await self.database.execute_search( - search=search, + return self.get_search( + request=kwargs["request"], + collections=[collection_id], + bbox=bbox, + datetime=datetime, limit=limit, - sort=None, - token=token, # type: ignore - collection_ids=[collection_id], - ) - - items = [ - self.item_serializer.db_to_stac(item, base_url=base_url) for item in items - ] - - links = await PagingLinks(request=request, next=next_token).get_links() - - return stac_types.ItemCollection( - type="FeatureCollection", - features=items, - links=links, - numReturned=len(items), - numMatched=maybe_count, + token=token, ) async def get_item( @@ -531,108 +355,8 @@ async def post_search( Raises: HTTPException: If there is an error with the cql2_json filter. 
""" - base_url = str(request.base_url) - - search = self.database.make_search() - - if search_request.ids: - search = self.database.apply_ids_filter( - search=search, item_ids=search_request.ids - ) - - if search_request.collections: - search = self.database.apply_collections_filter( - search=search, collection_ids=search_request.collections - ) - - if search_request.datetime: - datetime_search = self._return_date(search_request.datetime) - search = self.database.apply_datetime_filter( - search=search, datetime_search=datetime_search - ) - - if search_request.bbox: - bbox = search_request.bbox - if len(bbox) == 6: - bbox = [bbox[0], bbox[1], bbox[3], bbox[4]] - - search = self.database.apply_bbox_filter(search=search, bbox=bbox) - - if search_request.intersects: - search = self.database.apply_intersects_filter( - search=search, intersects=search_request.intersects - ) - - if search_request.query: - for field_name, expr in search_request.query.items(): - field = "properties__" + field_name - for op, value in expr.items(): - # Convert enum to string - operator = op.value if isinstance(op, Enum) else op - search = self.database.apply_stacql_filter( - search=search, op=operator, field=field, value=value - ) - - # only cql2_json is supported here - if hasattr(search_request, "filter"): - cql2_filter = getattr(search_request, "filter", None) - try: - search = self.database.apply_cql2_filter(search, cql2_filter) - except Exception as e: - raise HTTPException( - status_code=400, detail=f"Error with cql2_json filter: {e}" - ) - - sort = None - if search_request.sortby: - sort = self.database.populate_sort(search_request.sortby) - - limit = 10 - if search_request.limit: - limit = search_request.limit - - items, maybe_count, next_token = await self.database.execute_search( - search=search, - limit=limit, - token=search_request.token, # type: ignore - sort=sort, - collection_ids=search_request.collections, - ) - - items = [ - self.item_serializer.db_to_stac(item, base_url=base_url) for item in items - ] - - if self.extension_is_enabled("FieldsExtension"): - if search_request.query is not None: - query_include: Set[str] = set( - [ - k if k in Settings.get().indexed_fields else f"properties.{k}" - for k in search_request.query.keys() - ] - ) - if not search_request.fields.include: - search_request.fields.include = query_include - else: - search_request.fields.include.union(query_include) - - filter_kwargs = search_request.fields.filter_fields - - items = [ - orjson.loads( - stac_pydantic.Item(**feat).json(**filter_kwargs, exclude_unset=True) - ) - for feat in items - ] - - links = await PagingLinks(request=request, next=next_token).get_links() - - return stac_types.ItemCollection( - type="FeatureCollection", - features=items, - links=links, - numReturned=len(items), - numMatched=maybe_count, + return await self.database.execute_search( + search_request=search_request, request=request ) @@ -747,13 +471,9 @@ async def create_collection( Raises: ConflictError: If the collection already exists. 
""" - collection = collection.model_dump(mode="json") - base_url = str(kwargs["request"].base_url) - collection = self.database.collection_serializer.stac_to_db( - collection, base_url + return await self.database.create_collection( + collection=collection, request=kwargs["request"] ) - await self.database.create_collection(collection=collection) - return CollectionSerializer.db_to_stac(collection, base_url) @overrides async def update_collection( @@ -778,19 +498,12 @@ async def update_collection( A STAC collection that has been updated in the database. """ - collection = collection.model_dump(mode="json") - - base_url = str(kwargs["request"].base_url) - - collection = self.database.collection_serializer.stac_to_db( - collection, base_url - ) - await self.database.update_collection( - collection_id=collection_id, collection=collection + return await self.database.update_collection( + collection_id=collection_id, + collection=collection, + request=kwargs["request"], ) - return CollectionSerializer.db_to_stac(collection, base_url) - @overrides async def delete_collection( self, collection_id: str, **kwargs @@ -810,8 +523,9 @@ async def delete_collection( Raises: NotFoundError: If the collection doesn't exist. """ - await self.database.delete_collection(collection_id=collection_id) - return None + return await self.database.delete_collection( + collection_id=collection_id, request=kwargs["request"] + ) @attr.s diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index ddb6648b..eacf105c 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -1,15 +1,20 @@ """Database logic.""" + import asyncio import logging import os from base64 import urlsafe_b64decode, urlsafe_b64encode +from mimetypes import MimeTypes from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union +from urllib.parse import urljoin import attr from elasticsearch_dsl import Q, Search +from fastapi import Request from elasticsearch import exceptions, helpers # type: ignore from stac_fastapi.core.extensions import filter +from stac_fastapi.core.models.links import PagingLinks from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings @@ -315,7 +320,7 @@ class DatabaseLogic: """CORE LOGIC""" async def get_all_collections( - self, token: Optional[str], limit: int, base_url: str + self, request: Request ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """Retrieve a list of all collections from Elasticsearch, supporting pagination. @@ -326,6 +331,10 @@ async def get_all_collections( Returns: A tuple of (collections, next pagination token if any). 
""" + base_url = str(request.base_url) + limit = int(request.query_params.get("limit", 10)) + token = request.query_params.get("token") + search_after = None if token: search_after = [token] @@ -351,9 +360,23 @@ async def get_all_collections( if len(hits) == limit: next_token = hits[-1]["sort"][0] - return collections, next_token + links = [ + {"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url}, + {"rel": Relations.parent.value, "type": MimeTypes.json, "href": base_url}, + { + "rel": Relations.self.value, + "type": MimeTypes.json, + "href": urljoin(base_url, "collections"), + }, + ] + + if next_token: + next_link = PagingLinks(next=next_token, request=request).link_next() + links.append(next_link) - async def get_one_item(self, collection_id: str, item_id: str) -> Dict: + return stac_types.Collections(collections=collections, links=links) + + async def get_item(self, collection_id: str, item_id: str) -> Dict: """Retrieve a single item from the database. Args: @@ -542,11 +565,8 @@ def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]: async def execute_search( self, - search: Search, - limit: int, - token: Optional[str], - sort: Optional[Dict[str, Dict[str, str]]], - collection_ids: Optional[List[str]], + search_request: BaseSearchPostRequest, + request: Request, ignore_unavailable: bool = True, ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]: """Execute a search query with limit and other optional parameters. @@ -570,8 +590,71 @@ async def execute_search( Raises: NotFoundError: If the collections specified in `collection_ids` do not exist. """ - search_after = None + search = self.make_search() + + if search_request.ids: + search = self.apply_ids_filter(search=search, item_ids=search_request.ids) + + if search_request.collections: + search = self.apply_collections_filter( + search=search, collection_ids=search_request.collections + ) + + if search_request.datetime: + datetime_search = self._return_date(search_request.datetime) + search = self.apply_datetime_filter( + search=search, datetime_search=datetime_search + ) + + if search_request.bbox: + bbox = search_request.bbox + if len(bbox) == 6: + bbox = [bbox[0], bbox[1], bbox[3], bbox[4]] + + search = self.apply_bbox_filter(search=search, bbox=bbox) + + if search_request.intersects: + search = self.apply_intersects_filter( + search=search, intersects=search_request.intersects + ) + + if search_request.query: + for field_name, expr in search_request.query.items(): + field = "properties__" + field_name + for op, value in expr.items(): + # Convert enum to string + operator = op.value if isinstance(op, Enum) else op + search = self.apply_stacql_filter( + search=search, op=operator, field=field, value=value + ) + + # only cql2_json is supported here + if hasattr(search_request, "filter"): + cql2_filter = getattr(search_request, "filter", None) + try: + search = self.apply_cql2_filter(search, cql2_filter) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Error with cql2_json filter: {e}" + ) + + sort = None + if search_request.sortby: + sort = self.populate_sort(search_request.sortby) + + limit = 10 + if search_request.limit: + limit = search_request.limit + + items, maybe_count, next_token = await self.database.execute_search( + search=search, + limit=limit, + token=search_request.token, # type: ignore + sort=sort, + collection_ids=search_request.collections, + ) + search_after = None if token: search_after = urlsafe_b64decode(token.encode()).decode().split(",") 
@@ -628,7 +711,42 @@ async def execute_search( except Exception as e: logger.error(f"Count task failed: {e}") - return items, matched, next_token + items = [ + self.item_serializer.db_to_stac(item, base_url=str(request.base_url)) + for item in items + ] + + if self.extension_is_enabled("FieldsExtension"): + if search_request.query is not None: + query_include: Set[str] = set( + [ + k if k in Settings.get().indexed_fields else f"properties.{k}" + for k in search_request.query.keys() + ] + ) + if not search_request.fields.include: + search_request.fields.include = query_include + else: + search_request.fields.include.union(query_include) + + filter_kwargs = search_request.fields.filter_fields + + items = [ + orjson.loads( + stac_pydantic.Item(**feat).json(**filter_kwargs, exclude_unset=True) + ) + for feat in items + ] + + links = await PagingLinks(request=request, next=next_token).get_links() + + return stac_types.ItemCollection( + type="FeatureCollection", + features=items, + links=links, + numReturned=len(items), + numMatched=maybe_count, + ) """ TRANSACTION LOGIC """ @@ -756,7 +874,9 @@ async def delete_item( f"Item {item_id} in collection {collection_id} not found" ) - async def create_collection(self, collection: Collection, refresh: bool = False): + async def create_collection( + self, collection: Collection, request: Request, refresh: bool = False + ): """Create a single collection in the database. Args: @@ -769,6 +889,7 @@ async def create_collection(self, collection: Collection, refresh: bool = False) Notes: A new index is created for the items in the Collection using the `create_item_index` function. """ + collection = collection.model_dump(mode="json") collection_id = collection["id"] if await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id): @@ -783,7 +904,11 @@ async def create_collection(self, collection: Collection, refresh: bool = False) await create_item_index(collection_id) - async def find_collection(self, collection_id: str) -> Collection: + return self.collection_serializer.db_to_stac( + collection=collection, base_url=str(request.base_url) + ) + + async def get_collection(self, collection_id: str, request: Request) -> Collection: """Find and return a collection from the database. Args: @@ -804,13 +929,19 @@ async def find_collection(self, collection_id: str) -> Collection: collection = await self.client.get( index=COLLECTIONS_INDEX, id=collection_id ) - except exceptions.NotFoundError: - raise NotFoundError(f"Collection {collection_id} not found") + except exceptions.NotFoundError as exc: + raise NotFoundError(f"Collection {collection_id} not found") from exc - return collection["_source"] + return self.collection_serializer.db_to_stac( + collection=collection["_source"], base_url=str(request.base_url) + ) async def update_collection( - self, collection_id: str, collection: Collection, refresh: bool = False + self, + collection_id: str, + collection: Collection, + request: Request, + refresh: bool = False, ): """Update a collection from the database. @@ -828,10 +959,14 @@ async def update_collection( `collection_id` and with the collection specified in the `Collection` object. If the collection is not found, a `NotFoundError` is raised. 
""" - await self.find_collection(collection_id=collection_id) + await self.get_collection(collection_id=collection_id, request=request) + + collection = collection.model_dump(mode="json") if collection_id != collection["id"]: - await self.create_collection(collection, refresh=refresh) + await self.create_collection( + collection=collection, request=request, refresh=refresh + ) await self.client.reindex( body={ @@ -856,7 +991,13 @@ async def update_collection( refresh=refresh, ) - async def delete_collection(self, collection_id: str, refresh: bool = False): + return self.collection_serializer.db_to_stac( + collection=collection, base_url=str(request.base_url) + ) + + async def delete_collection( + self, collection_id: str, request: Request, refresh: bool = False + ): """Delete a collection from the database. Parameters: @@ -872,11 +1013,12 @@ async def delete_collection(self, collection_id: str, refresh: bool = False): deletes the collection. If `refresh` is set to True, the index is refreshed after the deletion. Additionally, this function also calls `delete_item_index` to delete the index for the items in the collection. """ - await self.find_collection(collection_id=collection_id) + await self.get_collection(collection_id=collection_id, request=request) await self.client.delete( index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh ) await delete_item_index(collection_id) + return None async def bulk_async( self, collection_id: str, processed_items: List[Item], refresh: bool = False From 40f2813ef38a9431d77f60ded5517132d487e38f Mon Sep 17 00:00:00 2001 From: rhysrevans3 Date: Wed, 12 Jun 2024 12:03:13 +0100 Subject: [PATCH 2/5] Moving item database code to database layer. --- stac_fastapi/core/stac_fastapi/core/core.py | 104 +++------------- .../core/stac_fastapi/core/utilities.py | 56 ++++++++- .../elasticsearch/database_logic.py | 111 ++++++++++++------ 3 files changed, 149 insertions(+), 122 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index fd69a96c..c79a3206 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -2,29 +2,22 @@ import logging import re -from datetime import datetime as datetime_type -from datetime import timezone -from enum import Enum -from typing import Any, Dict, List, Optional, Set, Type, Union -from urllib.parse import unquote_plus, urljoin +from typing import Any, Dict, List, Optional, Union +from urllib.parse import unquote_plus import attr import orjson -import stac_pydantic from fastapi import HTTPException, Request from overrides import overrides from pydantic import ValidationError from pygeofilter.backends.cql2_json import to_cql2 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.links import Relations -from stac_pydantic.shared import BBox, MimeTypes +from stac_pydantic.shared import BBox from stac_pydantic.version import STAC_VERSION from stac_fastapi.core.base_database_logic import BaseDatabaseLogic from stac_fastapi.core.base_settings import ApiBaseSettings -from stac_fastapi.core.models.links import PagingLinks -from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer from stac_fastapi.core.session import Session from stac_fastapi.extensions.third_party.bulk_transactions import ( BaseBulkTransactionsClient, @@ -32,7 +25,6 @@ Items, ) from stac_fastapi.types import stac as stac_types -from 
stac_fastapi.types.config import Settings from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.core import ( AsyncBaseCoreClient, @@ -40,7 +32,6 @@ AsyncBaseTransactionsClient, ) from stac_fastapi.types.extension import ApiExtension -from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.search import BaseSearchPostRequest @@ -55,13 +46,10 @@ class CoreClient(AsyncBaseCoreClient): This class is a implementation of `AsyncBaseCoreClient` that implements the core endpoints defined by the STAC specification. It uses the `DatabaseLogic` class to interact with the - database, and `ItemSerializer` and `CollectionSerializer` to convert between STAC objects and database records. Attributes: session (Session): A requests session instance to be used for all HTTP requests. - item_serializer (Type[serializers.ItemSerializer]): A serializer class to be used to convert - between STAC items and database records. database (DatabaseLogic): An instance of the `DatabaseLogic` class that is used to interact with the database. """ @@ -73,7 +61,6 @@ class CoreClient(AsyncBaseCoreClient): extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list)) session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) - item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer) post_request_model = attr.ib(default=BaseSearchPostRequest) stac_version: str = attr.ib(default=STAC_VERSION) landing_page_id: str = attr.ib(default="stac-fastapi") @@ -106,7 +93,7 @@ async def get_collection( Raises: NotFoundError: If the collection with the given id cannot be found in the database. """ - return await self.database.find_collection( + return await self.database.get_collection( collection_id=collection_id, request=kwargs["request"] ) @@ -162,63 +149,9 @@ async def get_item( Exception: If any error occurs while getting the item from the database. NotFoundError: If the item does not exist in the specified collection. """ - base_url = str(kwargs["request"].base_url) - item = await self.database.get_one_item( - item_id=item_id, collection_id=collection_id + return await self.database.get_item( + item_id=item_id, collection_id=collection_id, request=kwargs["request"] ) - return self.item_serializer.db_to_stac(item, base_url) - - @staticmethod - def _return_date( - interval: Optional[Union[DateTimeType, str]] - ) -> Dict[str, Optional[str]]: - """ - Convert a date interval. - - (which may be a datetime, a tuple of one or two datetimes a string - representing a datetime or range, or None) into a dictionary for filtering - search results with Elasticsearch. - - This function ensures the output dictionary contains 'gte' and 'lte' keys, - even if they are set to None, to prevent KeyError in the consuming logic. - - Args: - interval (Optional[Union[DateTimeType, str]]): The date interval, which might be a single datetime, - a tuple with one or two datetimes, a string, or None. - - Returns: - dict: A dictionary representing the date interval for use in filtering search results, - always containing 'gte' and 'lte' keys. - """ - result: Dict[str, Optional[str]] = {"gte": None, "lte": None} - - if interval is None: - return result - - if isinstance(interval, str): - if "/" in interval: - parts = interval.split("/") - result["gte"] = parts[0] if parts[0] != ".." else None - result["lte"] = ( - parts[1] if len(parts) > 1 and parts[1] != ".." 
else None - ) - else: - converted_time = interval if interval != ".." else None - result["gte"] = result["lte"] = converted_time - return result - - if isinstance(interval, datetime_type): - datetime_iso = interval.isoformat() - result["gte"] = result["lte"] = datetime_iso - elif isinstance(interval, tuple): - start, end = interval - # Ensure datetimes are converted to UTC and formatted with 'Z' - if start: - result["gte"] = start.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" - if end: - result["lte"] = end.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" - - return result def _format_datetime_range(self, date_tuple: DateTimeType) -> str: """ @@ -335,9 +268,7 @@ async def get_search( search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") - resp = await self.post_search(search_request=search_request, request=request) - - return resp + return await self.post_search(search_request=search_request, request=request) async def post_search( self, search_request: BaseSearchPostRequest, request: Request @@ -405,9 +336,9 @@ async def create_item( return None else: - item = await self.database.prep_create_item(item=item, base_url=base_url) - await self.database.create_item(item, refresh=kwargs.get("refresh", False)) - return ItemSerializer.db_to_stac(item, base_url) + return await self.database.create_item( + item, refresh=kwargs.get("refresh", False), request=kwargs["request"] + ) @overrides async def update_item( @@ -428,16 +359,13 @@ async def update_item( NotFound: If the specified collection is not found in the database. """ - item = item.model_dump(mode="json") - base_url = str(kwargs["request"].base_url) - now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z") - item["properties"]["updated"] = now - - await self.database.check_collection_exists(collection_id) - await self.delete_item(item_id=item_id, collection_id=collection_id) - await self.create_item(collection_id=collection_id, item=Item(**item), **kwargs) - return ItemSerializer.db_to_stac(item, base_url) + return self.database.update_item( + collection_id=collection_id, + item_id=item, + item=item, + request=kwargs["request"], + ) @overrides async def delete_item( diff --git a/stac_fastapi/core/stac_fastapi/core/utilities.py b/stac_fastapi/core/stac_fastapi/core/utilities.py index faa4f6a9..a9b07c37 100644 --- a/stac_fastapi/core/stac_fastapi/core/utilities.py +++ b/stac_fastapi/core/stac_fastapi/core/utilities.py @@ -3,7 +3,11 @@ This module contains functions for transforming geospatial coordinates, such as converting bounding boxes to polygon representations. """ -from typing import List + +from datetime import datetime +from typing import Dict, List, Optional, Union + +from stac_fastapi.types.rfc3339 import DateTimeType MAX_LIMIT = 10000 @@ -21,3 +25,53 @@ def bbox2polygon(b0: float, b1: float, b2: float, b3: float) -> List[List[List[f List[List[List[float]]]: A polygon represented as a list of lists of coordinates. """ return [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] + + +def return_date( + interval: Optional[Union[DateTimeType, str]] +) -> Dict[str, Optional[str]]: + """ + Convert a date interval. + + (which may be a datetime, a tuple of one or two datetimes a string + representing a datetime or range, or None) into a dictionary for filtering + search results with Elasticsearch. 
+ + This function ensures the output dictionary contains 'gte' and 'lte' keys, + even if they are set to None, to prevent KeyError in the consuming logic. + + Args: + interval (Optional[Union[DateTimeType, str]]): The date interval, which might be a single datetime, + a tuple with one or two datetimes, a string, or None. + + Returns: + dict: A dictionary representing the date interval for use in filtering search results, + always containing 'gte' and 'lte' keys. + """ + result: Dict[str, Optional[str]] = {"gte": None, "lte": None} + + if interval is None: + return result + + if isinstance(interval, str): + if "/" in interval: + parts = interval.split("/") + result["gte"] = parts[0] if parts[0] != ".." else None + result["lte"] = parts[1] if len(parts) > 1 and parts[1] != ".." else None + else: + converted_time = interval if interval != ".." else None + result["gte"] = result["lte"] = converted_time + return result + + if isinstance(interval, datetime): + datetime_iso = interval.isoformat() + result["gte"] = result["lte"] = datetime_iso + elif isinstance(interval, tuple): + start, end = interval + # Ensure datetimes are converted to UTC and formatted with 'Z' + if start: + result["gte"] = start.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + if end: + result["lte"] = end.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + + return result diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index eacf105c..db140399 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -4,25 +4,43 @@ import logging import os from base64 import urlsafe_b64decode, urlsafe_b64encode +from datetime import datetime as datetime_type +from datetime import timezone from mimetypes import MimeTypes -from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Protocol, + Set, + Tuple, + Type, + Union, +) from urllib.parse import urljoin import attr +import orjson from elasticsearch_dsl import Q, Search from fastapi import Request +from stac_pydantic.links import Relations +from stac_pydantic.shared import MimeTypes from elasticsearch import exceptions, helpers # type: ignore from stac_fastapi.core.extensions import filter from stac_fastapi.core.models.links import PagingLinks from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer -from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon +from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon, return_date from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings from stac_fastapi.elasticsearch.config import ( ElasticsearchSettings as SyncElasticsearchSettings, ) +from stac_fastapi.types.config import Settings from stac_fastapi.types.errors import ConflictError, NotFoundError -from stac_fastapi.types.stac import Collection, Item +from stac_fastapi.types.search import BaseSearchPostRequest +from stac_fastapi.types.stac import Collection, Item, ItemCollection logger = logging.getLogger(__name__) @@ -376,7 +394,9 @@ async def get_all_collections( return stac_types.Collections(collections=collections, links=links) - async def get_item(self, collection_id: str, item_id: str) -> Dict: + async def get_item( + self, collection_id: str, item_id: str, request: Request + ) -> Dict: """Retrieve a single item from the database. 
Args: @@ -402,7 +422,9 @@ async def get_item(self, collection_id: str, item_id: str) -> Dict: raise NotFoundError( f"Item {item_id} does not exist in Collection {collection_id}" ) - return item["_source"] + return self.item_serializer.db_to_stac( + item=item["_source"], base_url=str(request.base_url) + ) @staticmethod def make_search(): @@ -601,7 +623,7 @@ async def execute_search( ) if search_request.datetime: - datetime_search = self._return_date(search_request.datetime) + datetime_search = return_date(search_request.datetime) search = self.apply_datetime_filter( search=search, datetime_search=datetime_search ) @@ -633,10 +655,8 @@ async def execute_search( cql2_filter = getattr(search_request, "filter", None) try: search = self.apply_cql2_filter(search, cql2_filter) - except Exception as e: - raise HTTPException( - status_code=400, detail=f"Error with cql2_json filter: {e}" - ) + except Exception as exc: + raise Exception("Error with cql2_json filter") from exc sort = None if search_request.sortby: @@ -646,21 +666,15 @@ async def execute_search( if search_request.limit: limit = search_request.limit - items, maybe_count, next_token = await self.database.execute_search( - search=search, - limit=limit, - token=search_request.token, # type: ignore - sort=sort, - collection_ids=search_request.collections, - ) - search_after = None - if token: - search_after = urlsafe_b64decode(token.encode()).decode().split(",") + if hasattr(search_request, "token"): + search_after = ( + urlsafe_b64decode(search_request.token.encode()).decode().split(",") + ) query = search.query.to_dict() if search.query else None - index_param = indices(collection_ids) + index_param = indices(search_request.collections) max_result_window = MAX_LIMIT @@ -687,8 +701,10 @@ async def execute_search( try: es_response = await search_task - except exceptions.NotFoundError: - raise NotFoundError(f"Collections '{collection_ids}' do not exist") + except exceptions.NotFoundError as exc: + raise NotFoundError( + f"Collections '{search_request.collections}' do not exist" + ) from exc hits = es_response["hits"]["hits"] items = (hit["_source"] for hit in hits[:limit]) @@ -708,8 +724,8 @@ async def execute_search( if count_task.done(): try: matched = count_task.result().get("count") - except Exception as e: - logger.error(f"Count task failed: {e}") + except Exception as exc: + logger.error("Count task failed: %s", exc) items = [ self.item_serializer.db_to_stac(item, base_url=str(request.base_url)) @@ -732,20 +748,18 @@ async def execute_search( filter_kwargs = search_request.fields.filter_fields items = [ - orjson.loads( - stac_pydantic.Item(**feat).json(**filter_kwargs, exclude_unset=True) - ) - for feat in items + orjson.loads(item.json(**filter_kwargs, exclude_unset=True)) + for item in items ] links = await PagingLinks(request=request, next=next_token).get_links() - return stac_types.ItemCollection( + return ItemCollection( type="FeatureCollection", features=items, links=links, numReturned=len(items), - numMatched=maybe_count, + numMatched=matched, ) """ TRANSACTION LOGIC """ @@ -822,7 +836,7 @@ def sync_prep_create_item( return self.item_serializer.stac_to_db(item, base_url) - async def create_item(self, item: Item, refresh: bool = False): + async def create_item(self, item: Item, request: Request, refresh: bool = False): """Database logic for creating one item. 
Args: @@ -836,6 +850,8 @@ async def create_item(self, item: Item, refresh: bool = False): None """ # todo: check if collection exists, but cache + self.prep_create_item(item=item, base_url=request.base_url) + item_id = item["id"] collection_id = item["collection"] es_resp = await self.client.index( @@ -850,6 +866,8 @@ async def create_item(self, item: Item, refresh: bool = False): f"Item {item_id} in collection {collection_id} already exists" ) + return self.item_serializer.db_to_stac(item, request.base_url) + async def delete_item( self, item_id: str, collection_id: str, refresh: bool = False ): @@ -874,6 +892,33 @@ async def delete_item( f"Item {item_id} in collection {collection_id} not found" ) + async def update_item( + self, + item_id: str, + collection_id: str, + item: Item, + request: Request, + refresh: bool = False, + ): + """Update an item from the database. + + Args: + item_id (str): The id of the Item to be updated. + collection_id (str): The id of the Collection that the Item belongs to. + item (Item): The updated item. + refresh (bool, optional): Whether to refresh the index after the deletion. Default is False. + + Raises: + NotFoundError: If the Item does not exist in the database. + """ + item.properties["updated"] = ( + datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z") + ) + + await self.check_collection_exists(collection_id) + await self.delete_item(item_id=item_id, collection_id=collection_id) + return await self.create_item(item=item, refresh=refresh, request=request) + async def create_collection( self, collection: Collection, request: Request, refresh: bool = False ): @@ -981,7 +1026,7 @@ async def update_collection( refresh=refresh, ) - await self.delete_collection(collection_id) + await self.delete_collection(collection_id=collection_id, request=request) else: await self.client.index( From c42b7c41fe2af45d6831485557da131f8fdf1757 Mon Sep 17 00:00:00 2001 From: rhysrevans3 Date: Fri, 14 Jun 2024 14:02:32 +0100 Subject: [PATCH 3/5] Adding extension check to db. --- stac_fastapi/core/stac_fastapi/core/core.py | 8 ++++++++ .../stac_fastapi/elasticsearch/database_logic.py | 13 +++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index c79a3206..3fa61c9c 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -67,6 +67,14 @@ class CoreClient(AsyncBaseCoreClient): title: str = attr.ib(default="stac-fastapi") description: str = attr.ib(default="stac-fastapi") + def __attrs_post_init__(self): + """Load extensions into database.""" + self.database = self.database.load_extensions(self.extensions) + + def extension_is_enabled(self, extension: str) -> bool: + """Check if an api extension is enabled.""" + return any([type(ext).__name__ == extension for ext in self.extensions]) + async def all_collections(self, **kwargs) -> stac_types.Collections: """Read all collections from the database. 
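
A note on the wiring introduced in this commit: `__attrs_post_init__` runs once attrs has finished `__init__`, which is what lets the client push its configured extensions down into the database layer (the `load_extensions` it calls is added to `DatabaseLogic` in the hunk below). A self-contained sketch using simplified stand-ins for both classes — note that `load_extensions` mutates the list in place and returns `None`, so the sketch does not reassign its result:

    import attr

    @attr.s
    class DatabaseLogic:
        extensions: list = attr.ib(factory=list)

        def load_extensions(self, extensions: list) -> None:
            """Add extensions to the current extensions list."""
            self.extensions.extend(extensions)

        def extension_is_enabled(self, extension: str) -> bool:
            """Check if an api extension is enabled."""
            return any(type(ext).__name__ == extension for ext in self.extensions)

    class FieldsExtension:
        """Illustrative stand-in for the real extension class."""

    @attr.s
    class CoreClient:
        database: DatabaseLogic = attr.ib()
        extensions: list = attr.ib(factory=list)

        def __attrs_post_init__(self):
            # Mutates the database's extension list in place; nothing to reassign.
            self.database.load_extensions(self.extensions)

    client = CoreClient(database=DatabaseLogic(), extensions=[FieldsExtension()])
    assert client.database.extension_is_enabled("FieldsExtension")
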
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
index db140399..520a11eb 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -40,7 +40,7 @@
 from stac_fastapi.types.config import Settings
 from stac_fastapi.types.errors import ConflictError, NotFoundError
 from stac_fastapi.types.search import BaseSearchPostRequest
-from stac_fastapi.types.stac import Collection, Item, ItemCollection
+from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
 
 logger = logging.getLogger(__name__)
 
@@ -329,6 +329,7 @@ class DatabaseLogic:
 
     client = AsyncElasticsearchSettings().create_client
     sync_client = SyncElasticsearchSettings().create_client
+    extensions = attr.ib(default=[])
 
     item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer)
     collection_serializer: Type[CollectionSerializer] = attr.ib(
@@ -337,6 +338,14 @@ class DatabaseLogic:
 
     """CORE LOGIC"""
 
+    def load_extensions(self, extensions: list) -> None:
+        """Add extensions to the current extensions list.
+
+        Args:
+            extensions (list): list of extensions to add.
+        """
+        self.extensions.extend(extensions)
+
     async def get_all_collections(
         self, request: Request
     ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
@@ -392,7 +401,7 @@ async def get_all_collections(
             next_link = PagingLinks(next=next_token, request=request).link_next()
             links.append(next_link)
 
-        return stac_types.Collections(collections=collections, links=links)
+        return Collections(collections=collections, links=links)
 
     async def get_item(
         self, collection_id: str, item_id: str, request: Request

From 2eda01488dd3bf61f871cf6633ed5e86474f1e28 Mon Sep 17 00:00:00 2001
From: rhysrevans3
Date: Fri, 14 Jun 2024 14:06:21 +0100
Subject: [PATCH 4/5] Adding missing enum import.

---
 stac_fastapi/core/stac_fastapi/core/core.py      | 4 ----
 .../stac_fastapi/elasticsearch/database_logic.py | 5 +++++
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py
index 3fa61c9c..e167beae 100644
--- a/stac_fastapi/core/stac_fastapi/core/core.py
+++ b/stac_fastapi/core/stac_fastapi/core/core.py
@@ -71,10 +71,6 @@ def __attrs_post_init__(self):
         """Load extensions into database."""
         self.database = self.database.load_extensions(self.extensions)
 
-    def extension_is_enabled(self, extension: str) -> bool:
-        """Check if an api extension is enabled."""
-        return any([type(ext).__name__ == extension for ext in self.extensions])
-
     async def all_collections(self, **kwargs) -> stac_types.Collections:
         """Read all collections from the database.
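
Before the database_logic.py half of this commit below, a word on why the import matters: the query extension hands operators over as enum members, and the query loop that moved into `execute_search` converts each one to a plain string (`op.value if isinstance(op, Enum) else op`) before calling `apply_stacql_filter`, so `Enum` must be importable there. A tiny sketch, with an illustrative operator enum standing in for the stac-pydantic one:

    from enum import Enum

    class Operator(str, Enum):
        """Illustrative stand-in for the query-extension operator enum."""

        eq = "eq"
        gte = "gte"

    for op, value in {Operator.gte: 10}.items():
        # Same conversion as in execute_search: enum member -> plain string.
        operator = op.value if isinstance(op, Enum) else op
        assert operator == "gte"
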
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 520a11eb..7cd5d6f7 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -6,6 +6,7 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode from datetime import datetime as datetime_type from datetime import timezone +from enum import Enum from mimetypes import MimeTypes from typing import ( Any, @@ -346,6 +347,10 @@ def load_extensions(self, extensions: list) -> None: """ self.extensions.extend(extensions) + def extension_is_enabled(self, extension: str) -> bool: + """Check if an api extension is enabled.""" + return any([type(ext).__name__ == extension for ext in self.extensions]) + async def get_all_collections( self, request: Request ) -> Tuple[List[Dict[str, Any]], Optional[str]]: From 610ab15620980f371b8123f3bc7455f48e91f392 Mon Sep 17 00:00:00 2001 From: rhysrevans3 Date: Fri, 14 Jun 2024 14:30:08 +0100 Subject: [PATCH 5/5] Fixing precommits. --- stac_fastapi/core/stac_fastapi/core/core.py | 1 - .../stac_fastapi/elasticsearch/database_logic.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index e167beae..0d651266 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -363,7 +363,6 @@ async def update_item( NotFound: If the specified collection is not found in the database. """ - return self.database.update_item( collection_id=collection_id, item_id=item, diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 7cd5d6f7..9051104d 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -7,7 +7,6 @@ from datetime import datetime as datetime_type from datetime import timezone from enum import Enum -from mimetypes import MimeTypes from typing import ( Any, Dict, @@ -330,7 +329,7 @@ class DatabaseLogic: client = AsyncElasticsearchSettings().create_client sync_client = SyncElasticsearchSettings().create_client - extensions = attr.ib(default=[]) + extensions: list = attr.ib(default=[]) item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer) collection_serializer: Type[CollectionSerializer] = attr.ib( @@ -721,7 +720,7 @@ async def execute_search( ) from exc hits = es_response["hits"]["hits"] - items = (hit["_source"] for hit in hits[:limit]) + items = [hit["_source"] for hit in hits[:limit]] next_token = None if len(hits) > limit and limit < max_result_window: @@ -864,7 +863,7 @@ async def create_item(self, item: Item, request: Request, refresh: bool = False) None """ # todo: check if collection exists, but cache - self.prep_create_item(item=item, base_url=request.base_url) + await self.prep_create_item(item=item, base_url=request.base_url) item_id = item["id"] collection_id = item["collection"]
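
A closing note on the `items = [hit["_source"] for hit in hits[:limit]]` change in this last commit: materialising the hits as a list rather than a generator is the safer choice here, because the results may be iterated more than once (serialisation, then the optional fields-extension filtering pass), while a generator is exhausted after a single pass. A quick illustration with made-up hits:

    hits = [{"_source": {"id": "a"}}, {"_source": {"id": "b"}}]

    gen = (hit["_source"] for hit in hits)
    assert list(gen) == [{"id": "a"}, {"id": "b"}]
    assert list(gen) == []  # the generator is exhausted after one pass

    items = [hit["_source"] for hit in hits]
    assert list(items) == list(items)  # a list can be re-read safely
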