diff --git a/dockerfiles/Dockerfile.deploy.es b/dockerfiles/Dockerfile.deploy.es
index 2eab7b9d..ef87d1c4 100644
--- a/dockerfiles/Dockerfile.deploy.es
+++ b/dockerfiles/Dockerfile.deploy.es
@@ -3,9 +3,12 @@ FROM python:3.10-slim
 RUN apt-get update && \
     apt-get -y upgrade && \
     apt-get -y install gcc && \
+    apt-get -y install build-essential git && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
+
+
 ENV CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt

 WORKDIR /app
diff --git a/dockerfiles/Dockerfile.dev.es b/dockerfiles/Dockerfile.dev.es
index 009f9681..f0516031 100644
--- a/dockerfiles/Dockerfile.dev.es
+++ b/dockerfiles/Dockerfile.dev.es
@@ -4,7 +4,7 @@ FROM python:3.10-slim
 # update apt pkgs, and install build-essential for ciso8601
 RUN apt-get update && \
     apt-get -y upgrade && \
-    apt-get install -y build-essential git && \
+    apt-get -y install build-essential git && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
diff --git a/dockerfiles/Dockerfile.dev.os b/dockerfiles/Dockerfile.dev.os
index d9dc8b0a..d6488a74 100644
--- a/dockerfiles/Dockerfile.dev.os
+++ b/dockerfiles/Dockerfile.dev.os
@@ -4,10 +4,10 @@ FROM python:3.10-slim
 # update apt pkgs, and install build-essential for ciso8601
 RUN apt-get update && \
     apt-get -y upgrade && \
-    apt-get install -y build-essential && \
+    apt-get -y install build-essential git && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*

 # update certs used by Requests
 ENV CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
diff --git a/dockerfiles/Dockerfile.docs b/dockerfiles/Dockerfile.docs
index f1fe63b8..937ae014 100644
--- a/dockerfiles/Dockerfile.docs
+++ b/dockerfiles/Dockerfile.docs
@@ -1,4 +1,4 @@
-FROM python:3.8-slim
+FROM python:3.9-slim

 # build-essential is required to build a wheel for ciso8601
 RUN apt update && apt install -y build-essential
diff --git a/stac_fastapi/core/setup.py b/stac_fastapi/core/setup.py
index ddf786b6..4a936618 100644
--- a/stac_fastapi/core/setup.py
+++ b/stac_fastapi/core/setup.py
@@ -10,9 +10,9 @@
     "attrs>=23.2.0",
     "pydantic>=2.4.1,<3.0.0",
     "stac_pydantic~=3.1.0",
-    "stac-fastapi.api==5.2.0",
-    "stac-fastapi.extensions==5.2.0",
-    "stac-fastapi.types==5.2.0",
+    "stac-fastapi.types@git+https://github.com/stac-utils/stac-fastapi.git@refs/pull/744/head#subdirectory=stac_fastapi/types",
+    "stac-fastapi.api@git+https://github.com/stac-utils/stac-fastapi.git@refs/pull/744/head#subdirectory=stac_fastapi/api",
+    "stac-fastapi.extensions@git+https://github.com/stac-utils/stac-fastapi.git@refs/pull/744/head#subdirectory=stac_fastapi/extensions",
     "orjson~=3.9.0",
     "overrides~=7.4.0",
     "geojson-pydantic~=1.0.0",
diff --git a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
index 0043cfb8..50d2062c 100644
--- a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
+++ b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py
@@ -1,7 +1,7 @@
 """Base database logic."""

 import abc
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, List, Optional


 class BaseDatabaseLogic(abc.ABC):
@@ -29,6 +29,30 @@ async def create_item(self, item: Dict, refresh: bool = False) -> None:
         """Create an item in the database."""
         pass

+    @abc.abstractmethod
+    async def merge_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        item: Dict,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch an item in the database following RFC 7396."""
+        pass
+
+    @abc.abstractmethod
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch an item in the database following RFC 6902."""
+        pass
+
     @abc.abstractmethod
     async def delete_item(
         self, item_id: str, collection_id: str, refresh: bool = False
@@ -41,6 +65,28 @@ async def create_collection(self, collection: Dict, refresh: bool = False) -> No
         """Create a collection in the database."""
         pass

+    @abc.abstractmethod
+    async def merge_patch_collection(
+        self,
+        collection_id: str,
+        collection: Dict,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch a collection in the database following RFC 7396."""
+        pass
+
+    @abc.abstractmethod
+    async def json_patch_collection(
+        self,
+        collection_id: str,
+        operations: List,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Dict:
+        """Patch a collection in the database following RFC 6902."""
+        pass
+
     @abc.abstractmethod
     async def find_collection(self, collection_id: str) -> Dict:
         """Find a collection in the database."""
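For orientation before the implementations below: the two abstract patch flavours take different payloads. RFC 7396 (JSON Merge Patch) accepts a partial document in which `null` deletes a key, while RFC 6902 (JSON Patch) accepts an ordered list of operations. A hedged sketch of what a client would send to each, based on the content types and operation shapes used in the tests later in this diff (the host, collection, and item IDs are placeholders):

```python
import httpx

BASE = "http://localhost:8080"  # placeholder deployment
URL = f"{BASE}/collections/test-collection/items/test-item"

# RFC 7396 merge patch: a partial item; null removes a key.
httpx.patch(
    URL,
    json={"properties": {"gsd": 30, "landsat:column": None}},
    headers={"Content-Type": "application/merge-patch+json"},
)

# RFC 6902 JSON patch: an ordered list of operations.
httpx.patch(
    URL,
    json=[
        {"op": "replace", "path": "/properties/gsd", "value": 30},
        {"op": "remove", "path": "/properties/landsat:column"},
    ],
    headers={"Content-Type": "application/json-patch+json"},
)
```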
diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py
index 05212f5b..6e8217e8 100644
--- a/stac_fastapi/core/stac_fastapi/core/core.py
+++ b/stac_fastapi/core/stac_fastapi/core/core.py
@@ -12,7 +12,7 @@
 import orjson
 from fastapi import HTTPException, Request
 from overrides import overrides
-from pydantic import ValidationError
+from pydantic import TypeAdapter, ValidationError
 from pygeofilter.backends.cql2_json import to_cql2
 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
 from stac_pydantic import Collection, Item, ItemCollection
@@ -42,6 +42,9 @@
 logger = logging.getLogger(__name__)

+partialItemValidator = TypeAdapter(stac_types.PartialItem)
+partialCollectionValidator = TypeAdapter(stac_types.PartialCollection)
+

 @attr.s
 class CoreClient(AsyncBaseCoreClient):
@@ -376,7 +379,7 @@ async def get_item(

     @staticmethod
     def _return_date(
-        interval: Optional[Union[DateTimeType, str]]
+        interval: Optional[Union[DateTimeType, str]],
     ) -> Dict[str, Optional[str]]:
         """
         Convert a date interval.
@@ -503,6 +506,7 @@ async def get_search(
             "token": token,
             "query": orjson.loads(query) if query else query,
             "q": q,
+            "datetime": datetime,
         }

         if datetime:
@@ -765,6 +769,59 @@ async def update_item(

         return ItemSerializer.db_to_stac(item, base_url)

+    @overrides
+    async def patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        patch: Union[stac_types.PartialItem, List[stac_types.PatchOperation]],
+        **kwargs,
+    ):
+        """Patch an item in the collection.
+
+        Args:
+            collection_id (str): The ID of the collection the item belongs to.
+            item_id (str): The ID of the item to be updated.
+            patch (Union[stac_types.PartialItem, List[stac_types.PatchOperation]]): The item data or operations.
+            kwargs: Other optional arguments, including the request object.
+
+        Returns:
+            stac_types.Item: The updated item object.
+
+        Raises:
+            NotFound: If the specified collection is not found in the database.
+        """
+        base_url = str(kwargs["request"].base_url)
+
+        content_type = kwargs["request"].headers.get("content-type")
+
+        item = None
+        if isinstance(patch, list) and content_type == "application/json-patch+json":
+            item = await self.database.json_patch_item(
+                collection_id=collection_id,
+                item_id=item_id,
+                operations=patch,
+                base_url=base_url,
+            )
+
+        if isinstance(patch, dict) and content_type in [
+            "application/merge-patch+json",
+            "application/json",
+        ]:
+            patch = partialItemValidator.validate_python(patch)
+            item = await self.database.merge_patch_item(
+                collection_id=collection_id,
+                item_id=item_id,
+                item=patch,
+                base_url=base_url,
+            )
+
+        if item:
+            return ItemSerializer.db_to_stac(item, base_url=base_url)
+
+        raise NotImplementedError("Content-Type and body combination not implemented")
+
+ + """ + base_url = str(kwargs["request"].base_url) + + content_type = kwargs["request"].headers.get("content-type") + + item = None + if isinstance(patch, list) and content_type == "application/json-patch+json": + item = await self.database.json_patch_item( + collection_id=collection_id, + item_id=item_id, + operations=patch, + base_url=base_url, + ) + + if isinstance(patch, dict) and content_type in [ + "application/merge-patch+json", + "application/json", + ]: + patch = partialItemValidator.validate_python(patch) + item = await self.database.merge_patch_item( + collection_id=collection_id, + item_id=item_id, + item=patch, + base_url=base_url, + ) + + if item: + return ItemSerializer.db_to_stac(item, base_url=base_url) + + raise NotImplementedError("Content-Type and body combination not implemented") + @overrides async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> None: """Delete an item from a collection. @@ -846,6 +903,54 @@ async def update_collection( extensions=[type(ext).__name__ for ext in self.database.extensions], ) + def patch_collection( + self, + collection_id: str, + patch: Union[stac_types.PartialCollection, List[stac_types.PatchOperation]], + **kwargs, + ): + """Update a collection. + + Called with `PATCH /collections/{collection_id}` + + Args: + collection_id: id of the collection. + patch: either the partial collection or list of patch operations. + + Returns: + The patched collection. + """ + content_type = kwargs["request"].headers.get("content-type") + base_url = str(kwargs["request"].base_url) + + collection = None + if isinstance(patch, list) and content_type == "application/json-patch+json": + collection = self.database.json_patch_collection( + collection_id=collection_id, + operations=patch, + base_url=base_url, + ) + + if isinstance(patch, dict) and content_type in [ + "application/merge-patch+json", + "application/json", + ]: + patch = partialCollectionValidator.validate_python(patch) + collection = self.database.merge_patch_collection( + collection_id=collection_id, + collection=patch, + base_url=base_url, + ) + + if collection: + return CollectionSerializer.db_to_stac( + collection, + kwargs["request"], + extensions=[type(ext).__name__ for ext in self.database.extensions], + ) + + raise NotImplementedError("Content-Type and body combination not implemented") + @overrides async def delete_collection(self, collection_id: str, **kwargs) -> None: """ diff --git a/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py b/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py index d41d763c..31418f25 100644 --- a/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py +++ b/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py @@ -369,6 +369,7 @@ async def aggregate( "geometry_geohash_grid_frequency_precision": geometry_geohash_grid_frequency_precision, "geometry_geotile_grid_frequency_precision": geometry_geotile_grid_frequency_precision, "datetime_frequency_interval": datetime_frequency_interval, + "datetime": datetime, } if collection_id: diff --git a/stac_fastapi/core/stac_fastapi/core/models/patch.py b/stac_fastapi/core/stac_fastapi/core/models/patch.py new file mode 100644 index 00000000..7dbfd3db --- /dev/null +++ b/stac_fastapi/core/stac_fastapi/core/models/patch.py @@ -0,0 +1,155 @@ +"""patch helpers.""" + +import re +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel, computed_field, model_validator + +regex = re.compile(r"([^.' ]*:[^.' 
]*)\.?") + + +class ESCommandSet: + """Uses dictionary keys to behaviour of ordered set. + + Yields: + str: Elasticsearch commands + """ + + dict_: Dict[str, None] = {} + + def __init__(self): + """Initialise ESCommandSet instance.""" + self.dict_ = {} + + def add(self, value: str): + """Add command. + + Args: + value (str): value to be added + """ + self.dict_[value] = None + + def remove(self, value: str): + """Remove command. + + Args: + value (str): value to be removed + """ + del self.dict_[value] + + def __iter__(self): + """Iterate Elasticsearch commands. + + Yields: + str: Elasticsearch command + """ + yield from self.dict_.keys() + + +def to_es(string: str): + """Convert patch operation key to Elasticsearch key. + + Args: + string (str): string to be converted + + Returns: + _type_: converted string + """ + if matches := regex.findall(string): + for match in set(matches): + string = re.sub(rf"\.?{match}", f"['{match}']", string) + + return string + + +class ElasticPath(BaseModel): + """Converts a JSON path to an Elasticsearch path. + + Args: + path (str): JSON path to be converted. + + """ + + path: str + nest: Optional[str] = None + partition: Optional[str] = None + key: Optional[str] = None + + es_path: Optional[str] = None + es_nest: Optional[str] = None + es_key: Optional[str] = None + + index_: Optional[int] = None + + @model_validator(mode="before") + @classmethod + def validate_model(cls, data: Any): + """Set optional fields from JSON path. + + Args: + data (Any): input data + """ + data["path"] = data["path"].lstrip("/").replace("/", ".") + data["nest"], data["partition"], data["key"] = data["path"].rpartition(".") + + if data["key"].lstrip("-").isdigit() or data["key"] == "-": + data["index_"] = -1 if data["key"] == "-" else int(data["key"]) + data["path"] = f"{data['nest']}[{data['index_']}]" + data["nest"], data["partition"], data["key"] = data["nest"].rpartition(".") + + data["es_path"] = to_es(data["path"]) + data["es_nest"] = to_es(data["nest"]) + data["es_key"] = to_es(data["key"]) + + return data + + @computed_field # type: ignore[misc] + @property + def index(self) -> Union[int, str, None]: + """Compute location of path. + + Returns: + str: path index + """ + if self.index_ and self.index_ < 0: + + return f"ctx._source.{self.location}.size() - {-self.index_}" + + return self.index_ + + @computed_field # type: ignore[misc] + @property + def location(self) -> str: + """Compute location of path. + + Returns: + str: path location + """ + return self.nest + self.partition + self.key + + @computed_field # type: ignore[misc] + @property + def es_location(self) -> str: + """Compute location of path. + + Returns: + str: path location + """ + if self.es_key and ":" in self.es_key: + return self.es_nest + self.es_key + return self.es_nest + self.partition + self.es_key + + @computed_field # type: ignore[misc] + @property + def variable_name(self) -> str: + """Variable name for scripting. 
+
+        Returns:
+            str: variable name
+        """
+        if self.index is not None:
+            return f"{self.location.replace('.', '_').replace(':', '_')}_{self.index}"
+
+        return (
+            f"{self.nest.replace('.', '_').replace(':', '_')}_{self.key.replace(':', '_')}"
+        )
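To make the conversion concrete, here is roughly what `ElasticPath` derives for a colon-namespaced property and for an array index (a sketch exercising the class above; the printed values follow from `validate_model` and `to_es`):

```python
from stac_fastapi.core.models.patch import ElasticPath

# Colon-namespaced keys become bracket lookups so Painless can address them.
path = ElasticPath(path="/properties/proj:epsg")
print(path.path)     # properties.proj:epsg
print(path.nest)     # properties
print(path.key)      # proj:epsg
print(path.es_path)  # properties['proj:epsg']

# A trailing digit (or "-" for append) is split off as the index.
tail = ElasticPath(path="/properties/area/1")
print(tail.location)  # properties.area
print(tail.index)     # 1
print(tail.es_path)   # properties.area[1]
```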
diff --git a/stac_fastapi/core/stac_fastapi/core/utilities.py b/stac_fastapi/core/stac_fastapi/core/utilities.py
index d4a35109..84369de4 100644
--- a/stac_fastapi/core/stac_fastapi/core/utilities.py
+++ b/stac_fastapi/core/stac_fastapi/core/utilities.py
@@ -3,11 +3,18 @@
 This module contains functions for transforming geospatial coordinates,
 such as converting bounding boxes to polygon representations.
 """
+
 import logging
 import os
 from typing import Any, Dict, List, Optional, Set, Union

-from stac_fastapi.types.stac import Item
+from stac_fastapi.core.models.patch import ElasticPath, ESCommandSet
+from stac_fastapi.types.stac import (
+    Item,
+    PatchAddReplaceTest,
+    PatchOperation,
+    PatchRemove,
+)

 MAX_LIMIT = 10000
@@ -217,3 +224,181 @@ def dict_deep_update(merge_to: Dict[str, Any], merge_from: Dict[str, Any]) -> No
             dict_deep_update(merge_to[k], merge_from[k])
         else:
             merge_to[k] = v
+
+
+def merge_to_operations(data: Dict) -> List:
+    """Convert a merge patch into a list of RFC 6902 operations.
+
+    Args:
+        data: dictionary to convert.
+
+    Returns:
+        List: list of RFC 6902 operations.
+    """
+    operations = []
+
+    for key, value in data.copy().items():
+
+        if value is None:
+            operations.append(PatchRemove(op="remove", path=key))
+
+        elif isinstance(value, dict):
+            nested_operations = merge_to_operations(value)
+
+            for nested_operation in nested_operations:
+                nested_operation.path = f"{key}.{nested_operation.path}"
+                operations.append(nested_operation)
+
+        else:
+            operations.append(PatchAddReplaceTest(op="add", path=key, value=value))
+
+    return operations
+
+
+def check_commands(
+    commands: ESCommandSet,
+    op: str,
+    path: ElasticPath,
+    from_path: bool = False,
+) -> None:
+    """Add Elasticsearch existence checks for an operation.
+
+    Args:
+        commands (ESCommandSet): current commands
+        op (str): the operation of script
+        path (ElasticPath): path of variable to run operation on
+        from_path (bool): True if path is a from path
+    """
+    if path.nest:
+        commands.add(
+            f"if (!ctx._source.containsKey('{path.nest}'))"
+            f"{{Debug.explain('{path.nest} does not exist');}}"
+        )
+
+    if path.index or op in ["remove", "replace", "test"] or from_path:
+        commands.add(
+            f"if (!ctx._source.{path.es_nest}.containsKey('{path.key}'))"
+            f"{{Debug.explain('{path.key} does not exist in {path.nest}');}}"
+        )
+
+    if from_path and path.index is not None:
+        commands.add(
+            f"if ((ctx._source.{path.es_location} instanceof ArrayList"
+            f" && ctx._source.{path.es_location}.size() < {path.index})"
+            f" || (!(ctx._source.{path.es_location} instanceof ArrayList)"
+            f" && !ctx._source.{path.es_location}.containsKey('{path.index}')))"
+            f"{{Debug.explain('{path.path} does not exist');}}"
+        )
+
+
+def remove_commands(commands: ESCommandSet, path: ElasticPath) -> None:
+    """Remove value at path.
+
+    Args:
+        commands (ESCommandSet): current commands
+        path (ElasticPath): path to value to be removed
+    """
+    if path.index is not None:
+        commands.add(
+            f"def {path.variable_name} = ctx._source.{path.es_location}.remove({path.index});"
+        )
+
+    else:
+        commands.add(
+            f"def {path.variable_name} = ctx._source.{path.es_nest}.remove('{path.key}');"
+        )
+
+
+def add_commands(
+    commands: ESCommandSet,
+    operation: PatchOperation,
+    path: ElasticPath,
+    from_path: ElasticPath,
+) -> None:
+    """Add value at path.
+
+    Args:
+        commands (ESCommandSet): current commands
+        operation (PatchOperation): operation to run
+        path (ElasticPath): path for value to be added
+        from_path (ElasticPath): source path for move/copy operations
+    """
+    if from_path is not None:
+        value = (
+            from_path.variable_name
+            if operation.op == "move"
+            else f"ctx._source.{from_path.es_path}"
+        )
+    else:
+        value = operation.json_value
+
+    if path.index is not None:
+        commands.add(
+            f"if (ctx._source.{path.es_location} instanceof ArrayList)"
+            f"{{ctx._source.{path.es_location}.{'add' if operation.op in ['add', 'move'] else 'set'}({path.index}, {value})}}"
+            f"else{{ctx._source.{path.es_path} = {value}}}"
+        )
+
+    else:
+        commands.add(f"ctx._source.{path.es_path} = {value};")
+
+
+def test_commands(
+    commands: ESCommandSet, operation: PatchOperation, path: ElasticPath
+) -> None:
+    """Test value at path.
+
+    Args:
+        commands (ESCommandSet): current commands
+        operation (PatchOperation): operation to run
+        path (ElasticPath): path for value to be tested
+    """
+    commands.add(
+        f"if (ctx._source.{path.es_path} != {operation.json_value})"
+        f"{{Debug.explain('Test failed `{path.path}` | "
+        f"{operation.json_value} != ' + ctx._source.{path.es_path});}}"
+    )
+
+
+def operations_to_script(operations: List) -> Dict:
+    """Convert a list of RFC 6902 operations to a Painless script.
+
+    Args:
+        operations: list of RFC 6902 operations.
+
+    Returns:
+        Dict: Elasticsearch update script.
+    """
+    commands: ESCommandSet = ESCommandSet()
+    for operation in operations:
+        path = ElasticPath(path=operation.path)
+        from_path = (
+            ElasticPath(path=operation.from_) if hasattr(operation, "from_") else None
+        )
+
+        check_commands(commands=commands, op=operation.op, path=path)
+        if from_path is not None:
+            check_commands(
+                commands=commands, op=operation.op, path=from_path, from_path=True
+            )
+
+        if operation.op in ["remove", "move"]:
+            remove_path = from_path if from_path else path
+            remove_commands(commands=commands, path=remove_path)
+
+        if operation.op in ["add", "replace", "copy", "move"]:
+            add_commands(
+                commands=commands, operation=operation, path=path, from_path=from_path
+            )
+
+        if operation.op == "test":
+            test_commands(commands=commands, operation=operation, path=path)
+
+    source = "".join(commands)
+
+    return {
+        "source": source,
+        "lang": "painless",
+    }
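As a sanity check on the helpers above: a merge patch that sets one property and nulls another becomes two RFC 6902 operations, which then render into a single Painless update script (a sketch; the generated source is an implementation detail, so only its broad shape is shown in the comments):

```python
from stac_fastapi.core.utilities import merge_to_operations, operations_to_script

operations = merge_to_operations({"properties": {"gsd": 30, "landsat:column": None}})
# -> roughly: an "add" at "properties.gsd" with value 30,
#    then a "remove" at "properties.landsat:column"

script = operations_to_script(operations)
print(script["lang"])    # painless
print(script["source"])  # existence guards via Debug.explain(), then the
                         # assignment and remove() commands per operation
```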
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
index 958ee597..b2ef2a51 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -10,7 +10,9 @@
 import attr
 import elasticsearch.helpers as helpers
 from elasticsearch.dsl import Q, Search
+from elasticsearch.exceptions import BadRequestError
 from elasticsearch.exceptions import NotFoundError as ESNotFoundError
+from fastapi import HTTPException
 from starlette.requests import Request

 from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
@@ -31,13 +33,26 @@
 )
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
-from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon, validate_refresh
+from stac_fastapi.core.utilities import (
+    MAX_LIMIT,
+    bbox2polygon,
+    merge_to_operations,
+    operations_to_script,
+    validate_refresh,
+)
 from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings
 from stac_fastapi.elasticsearch.config import (
     ElasticsearchSettings as SyncElasticsearchSettings,
 )
 from stac_fastapi.types.errors import ConflictError, NotFoundError
-from stac_fastapi.types.stac import Collection, Item
+from stac_fastapi.types.links import resolve_links
+from stac_fastapi.types.stac import (
+    Collection,
+    Item,
+    PartialCollection,
+    PartialItem,
+    PatchOperation,
+)

 logger = logging.getLogger(__name__)
@@ -923,6 +938,135 @@ async def create_item(
             refresh=refresh,
         )

+    async def merge_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        item: PartialItem,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Item:
+        """Database logic for merge patching an item following RFC 7396.
+
+        Args:
+            collection_id (str): Collection that item belongs to.
+            item_id (str): Id of item to be patched.
+            item (PartialItem): The partial item to be updated.
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched item.
+        """
+        operations = merge_to_operations(item)
+
+        return await self.json_patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Item:
+        """Database logic for json patching an item following RFC 6902.
+
+        Args:
+            collection_id (str): Collection that item belongs to.
+            item_id (str): Id of item to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched item.
+        """
+        new_item_id = None
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if operation.path in ["collection", "id"] and operation.op in [
+                "add",
+                "replace",
+            ]:
+                if operation.path == "collection" and collection_id != operation.value:
+                    await self.check_collection_exists(collection_id=operation.value)
+                    new_collection_id = operation.value
+
+                if operation.path == "id" and item_id != operation.value:
+                    new_item_id = operation.value
+
+            else:
+                script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        try:
+            await self.client.update(
+                index=index_alias_by_collection_id(collection_id),
+                id=mk_item_id(item_id, collection_id),
+                script=script,
+                refresh=True,
+            )
+
+        except BadRequestError as exc:
+            raise HTTPException(
+                status_code=400, detail=exc.info["error"]["caused_by"]
+            ) from exc
+
+        item = await self.get_one_item(collection_id, item_id)
+
+        if new_collection_id:
+            await self.client.reindex(
+                body={
+                    "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}"},
+                    "source": {
+                        "index": f"{ITEMS_INDEX_PREFIX}{collection_id}",
+                        "query": {"term": {"id": {"value": item_id}}},
+                    },
+                    "script": {
+                        "lang": "painless",
+                        "source": (
+                            f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');"""
+                            f"""ctx._source.collection = '{new_collection_id}';"""
+                        ),
+                    },
+                },
+                wait_for_completion=True,
+                refresh=True,
+            )
+
+            await self.delete_item(
+                item_id=item_id,
+                collection_id=collection_id,
+                refresh=refresh,
+            )
+
+            item["collection"] = new_collection_id
+            collection_id = new_collection_id
+
+        if new_item_id:
+            item["id"] = new_item_id
+            item = await self.async_prep_create_item(item=item, base_url=base_url)
+            await self.create_item(item=item, refresh=False)
+
+            await self.delete_item(
+                item_id=item_id,
+                collection_id=collection_id,
+                refresh=refresh,
+            )
+
+        return item
+
     async def delete_item(self, item_id: str, collection_id: str, **kwargs: Any):
         """Delete a single item from the database.
@@ -1125,6 +1269,95 @@ async def update_collection(
             refresh=refresh,
         )

+    async def merge_patch_collection(
+        self,
+        collection_id: str,
+        collection: PartialCollection,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for merge patching a collection following RFC 7396.
+
+        Args:
+            collection_id (str): Id of collection to be patched.
+            collection (PartialCollection): The partial collection to be updated.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched collection.
+        """
+        operations = merge_to_operations(collection)
+
+        return await self.json_patch_collection(
+            collection_id=collection_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_collection(
+        self,
+        collection_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for json patching a collection following RFC 6902.
+
+        Args:
+            collection_id (str): Id of collection to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched collection.
+        """
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if (
+                operation.op in ["add", "replace"]
+                and operation.path == "collection"
+                and collection_id != operation.value
+            ):
+                new_collection_id = operation.value
+
+            else:
+                script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        try:
+            await self.client.update(
+                index=COLLECTIONS_INDEX,
+                id=collection_id,
+                script=script,
+                refresh=True,
+            )
+
+        except BadRequestError as exc:
+            raise HTTPException(
+                status_code=400, detail=exc.info["error"]["caused_by"]
+            ) from exc
+
+        collection = await self.find_collection(collection_id)
+
+        if new_collection_id:
+            collection["id"] = new_collection_id
+            collection["links"] = resolve_links([], base_url)
+
+            await self.update_collection(
+                collection_id=collection_id,
+                collection=collection,
+                refresh=False,
+            )
+
+        return collection
+
     async def delete_collection(self, collection_id: str, **kwargs: Any):
         """Delete a collection from the database.
@@ -1148,28 +1381,15 @@ async def delete_collection(self, collection_id: str, **kwargs: Any):
         # Ensure kwargs is a dictionary
         kwargs = kwargs or {}

-        # Verify that the collection exists
-        await self.find_collection(collection_id=collection_id)
-
-        # Resolve the `refresh` parameter
         refresh = kwargs.get("refresh", self.async_settings.database_refresh)
         refresh = validate_refresh(refresh)

-        # Log the deletion attempt
-        logger.info(f"Deleting collection {collection_id} with refresh={refresh}")
-
-        # Delete the collection from the database
+        # Verify that the collection exists
+        await self.find_collection(collection_id=collection_id)
         await self.client.delete(
             index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh
         )
-
-        # Delete the item index for the collection
-        try:
-            await delete_item_index(collection_id)
-        except Exception as e:
-            logger.error(
-                f"Failed to delete item index for collection {collection_id}: {e}"
-            )
+        await delete_item_index(collection_id)

     async def bulk_async(
         self,
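Note how renames are special-cased above: a patch that rewrites `collection` or `id` cannot be applied in place by the Painless script, so the logic reindexes the document into the target index and deletes the original. From the client's side that whole dance is one request; a hedged sketch (host and IDs are placeholders, and it assumes the top-level `collection` field is addressed with an RFC 6902-style path):

```python
import httpx

# Move an item to another collection with a single JSON Patch operation;
# the database layer resolves it via reindex + delete, not a script.
resp = httpx.patch(
    "http://localhost:8080/collections/landsat-8/items/scene-1",
    json=[{"op": "replace", "path": "/collection", "value": "landsat-9"}],
    headers={"Content-Type": "application/json-patch+json"},
)
resp.raise_for_status()
assert resp.json()["collection"] == "landsat-9"
```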
diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
index 71ab9275..c48fa925 100644
--- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
+++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
@@ -8,6 +8,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

 import attr
+from fastapi import HTTPException
 from opensearchpy import exceptions, helpers
 from opensearchpy.helpers.query import Q
 from opensearchpy.helpers.search import Search
@@ -31,13 +32,26 @@
 )
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
-from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon, validate_refresh
+from stac_fastapi.core.utilities import (
+    MAX_LIMIT,
+    bbox2polygon,
+    merge_to_operations,
+    operations_to_script,
+    validate_refresh,
+)
 from stac_fastapi.opensearch.config import (
     AsyncOpensearchSettings as AsyncSearchSettings,
 )
 from stac_fastapi.opensearch.config import OpensearchSettings as SyncSearchSettings
 from stac_fastapi.types.errors import ConflictError, NotFoundError
-from stac_fastapi.types.stac import Collection, Item
+from stac_fastapi.types.links import resolve_links
+from stac_fastapi.types.stac import (
+    Collection,
+    Item,
+    PartialCollection,
+    PartialItem,
+    PatchOperation,
+)

 logger = logging.getLogger(__name__)
@@ -937,6 +951,129 @@ async def create_item(
             refresh=refresh,
         )

+    async def merge_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        item: PartialItem,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Item:
+        """Database logic for merge patching an item following RFC 7396.
+
+        Args:
+            collection_id (str): Collection that item belongs to.
+            item_id (str): Id of item to be patched.
+            item (PartialItem): The partial item to be updated.
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched item.
+        """
+        operations = merge_to_operations(item)
+
+        return await self.json_patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_item(
+        self,
+        collection_id: str,
+        item_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Item:
+        """Database logic for json patching an item following RFC 6902.
+
+        Args:
+            collection_id (str): Collection that item belongs to.
+            item_id (str): Id of item to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing URLs for the item.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched item.
+        """
+        new_item_id = None
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if operation.path in ["collection", "id"] and operation.op in [
+                "add",
+                "replace",
+            ]:
+                if operation.path == "collection" and collection_id != operation.value:
+                    await self.check_collection_exists(collection_id=operation.value)
+                    new_collection_id = operation.value
+
+                if operation.path == "id" and item_id != operation.value:
+                    new_item_id = operation.value
+
+            else:
+                script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        try:
+            await self.client.update(
+                index=index_alias_by_collection_id(collection_id),
+                id=mk_item_id(item_id, collection_id),
+                body={"script": script},
+                refresh=True,
+            )
+
+        except exceptions.RequestError as exc:
+            raise HTTPException(
+                status_code=400, detail=exc.info["error"]["caused_by"]
+            ) from exc
+
+        item = await self.get_one_item(collection_id, item_id)
+
+        if new_collection_id:
+            await self.client.reindex(
+                body={
+                    "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}"},
+                    "source": {
+                        "index": f"{ITEMS_INDEX_PREFIX}{collection_id}",
+                        "query": {"term": {"id": {"value": item_id}}},
+                    },
+                    "script": {
+                        "lang": "painless",
+                        "source": (
+                            f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');"""
+                            f"""ctx._source.collection = '{new_collection_id}';"""
+                        ),
+                    },
+                },
+                wait_for_completion=True,
+                refresh=True,
+            )
+            item["collection"] = new_collection_id
+
+        if new_item_id:
+            item["id"] = new_item_id
+            item = await self.prep_create_item(item=item, base_url=base_url)
+            await self.create_item(item=item, refresh=False)
+
+        if new_collection_id or new_item_id:
+            await self.delete_item(
+                item_id=item_id,
+                collection_id=collection_id,
+                refresh=refresh,
+            )
+
+        return item
+
     async def delete_item(self, item_id: str, collection_id: str, **kwargs: Any):
         """Delete a single item from the database.
@@ -1113,6 +1250,93 @@ async def update_collection(
             refresh=refresh,
         )

+    async def merge_patch_collection(
+        self,
+        collection_id: str,
+        collection: PartialCollection,
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for merge patching a collection following RFC 7396.
+
+        Args:
+            collection_id (str): Id of collection to be patched.
+            collection (PartialCollection): The partial collection to be updated.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched collection.
+        """
+        operations = merge_to_operations(collection)
+
+        return await self.json_patch_collection(
+            collection_id=collection_id,
+            operations=operations,
+            base_url=base_url,
+            refresh=refresh,
+        )
+
+    async def json_patch_collection(
+        self,
+        collection_id: str,
+        operations: List[PatchOperation],
+        base_url: str,
+        refresh: bool = True,
+    ) -> Collection:
+        """Database logic for json patching a collection following RFC 6902.
+
+        Args:
+            collection_id (str): Id of collection to be patched.
+            operations (list): List of operations to run.
+            base_url (str): The base URL used for constructing links.
+            refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
+
+        Returns:
+            patched collection.
+        """
+        new_collection_id = None
+        script_operations = []
+
+        for operation in operations:
+            if (
+                operation.op in ["add", "replace"]
+                and operation.path == "collection"
+                and collection_id != operation.value
+            ):
+                new_collection_id = operation.value
+
+            else:
+                script_operations.append(operation)
+
+        script = operations_to_script(script_operations)
+
+        try:
+            await self.client.update(
+                index=COLLECTIONS_INDEX,
+                id=collection_id,
+                body={"script": script},
+                refresh=True,
+            )
+
+        except exceptions.RequestError as exc:
+            raise HTTPException(
+                status_code=400, detail=exc.info["error"]["caused_by"]
+            ) from exc
+
+        collection = await self.find_collection(collection_id)
+
+        if new_collection_id:
+            collection["id"] = new_collection_id
+            collection["links"] = resolve_links([], base_url)
+            await self.update_collection(
+                collection_id=collection_id, collection=collection, refresh=False
+            )
+
+        return collection
+
     async def delete_collection(self, collection_id: str, **kwargs: Any):
         """Delete a collection from the database.
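The OpenSearch implementation mirrors the Elasticsearch one almost line for line; the deliberate differences are the exception type translated into the 400 response (`opensearchpy.exceptions.RequestError` vs `elasticsearch.BadRequestError`) and how the script is handed to `update`. A side-by-side sketch of just that call (the index and document IDs are illustrative, following the alias and `mk_item_id` conventions above):

```python
# elasticsearch-py (8.x) accepts the script as a keyword argument:
await es_client.update(
    index="items_landsat-8",
    id="scene-1|landsat-8",
    script=script,
    refresh=True,
)

# opensearch-py expects the script wrapped in the request body:
await os_client.update(
    index="items_landsat-8",
    id="scene-1|landsat-8",
    body={"script": script},
    refresh=True,
)
```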
diff --git a/stac_fastapi/tests/api/test_api.py b/stac_fastapi/tests/api/test_api.py
index 807da5e4..301c5348 100644
--- a/stac_fastapi/tests/api/test_api.py
+++ b/stac_fastapi/tests/api/test_api.py
@@ -7,7 +7,7 @@

 from stac_fastapi.types.errors import ConflictError

-from ..conftest import create_collection, create_item
+from ..conftest import MockRequest, create_collection, create_item

 ROUTES = {
     "GET /_mgmt/ping",
@@ -33,7 +33,9 @@
     "POST /collections",
     "POST /collections/{collection_id}/items",
     "PUT /collections/{collection_id}",
+    "PATCH /collections/{collection_id}",
     "PUT /collections/{collection_id}/items/{item_id}",
+    "PATCH /collections/{collection_id}/items/{item_id}",
     "GET /aggregations",
     "GET /aggregate",
     "POST /aggregations",
@@ -616,6 +618,137 @@ async def test_bbox_3d(app_client, ctx):
     assert len(resp_json["features"]) == 1


+@pytest.mark.asyncio
+async def test_patch_json_collection(app_client, ctx):
+    data = {
+        "id": "new_id",
+        "summaries": {"hello": "world", "gsd": [50], "instruments": None},
+    }
+
+    resp = await app_client.patch(f"/collections/{ctx.collection['id']}", json=data)
+
+    assert resp.status_code == 200
+
+    new_resp = await app_client.get("/collections/new_id")
+    old_resp = await app_client.get(f"/collections/{ctx.collection['id']}")
+
+    assert new_resp.status_code == 200
+    assert old_resp.status_code == 404
+
+    new_resp_json = new_resp.json()
+
+    assert new_resp_json["id"] == "new_id"
+    assert new_resp_json["summaries"]["hello"] == "world"
+    assert "instruments" not in new_resp_json["summaries"]
+    assert new_resp_json["summaries"]["gsd"] == [50]
+    assert new_resp_json["summaries"]["platform"] == ["landsat-8"]
+
+
+@pytest.mark.asyncio
+async def test_patch_operations_collection(app_client, ctx):
+    operations = [
+        {"op": "add", "path": "/summaries/hello", "value": "world"},
+        {"op": "replace", "path": "/summaries/gsd", "value": [50]},
+        {
+            "op": "move",
+            "path": "/summaries/instrument",
+            "from": "/summaries/instruments",
+        },
+        {"op": "copy", "path": "/license", "from": "/summaries/license"},
+    ]
+
+    resp = await app_client.patch(
+        f"/collections/{ctx.item['collection']}",
+        json=operations,
+        headers={"Content-Type": "application/json-patch+json"},
+    )
+
+    assert resp.status_code == 200
+
+    new_resp = await app_client.get(f"/collections/{ctx.item['collection']}")
+
+    assert new_resp.status_code == 200
+
+    new_resp_json = new_resp.json()
+
+    assert new_resp_json["summaries"]["hello"] == "world"
+    assert "instruments" not in new_resp_json["summaries"]
+    assert new_resp_json["summaries"]["gsd"] == [50]
+    assert new_resp_json["license"] == "PDDL-1.0"
+    assert new_resp_json["summaries"]["license"] == "PDDL-1.0"
+    assert new_resp_json["summaries"]["instrument"] == ["oli", "tirs"]
+    assert new_resp_json["summaries"]["platform"] == ["landsat-8"]
+
+
+@pytest.mark.asyncio
+async def test_patch_json_item(app_client, ctx):
+
+    data = {
+        "id": "new_id",
+        "properties": {"hello": "world", "proj:epsg": 1000, "landsat:column": None},
+    }
+
+    resp = await app_client.patch(
+        f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}", json=data
+    )
+
+    assert resp.status_code == 200
+
+    new_resp = await app_client.get(
+        f"/collections/{ctx.item['collection']}/items/new_id"
+    )
+    old_resp = await app_client.get(
+        f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}"
+    )
+
+    assert new_resp.status_code == 200
+    assert old_resp.status_code == 404
+
+    new_resp_json = new_resp.json()
+
+    assert new_resp_json["id"] == "new_id"
+    assert new_resp_json["properties"]["hello"] == "world"
+    assert "landsat:column" not in new_resp_json["properties"]
+    assert new_resp_json["properties"]["proj:epsg"] == 1000
+    assert new_resp_json["properties"]["platform"] == "landsat-8"
+
+
+@pytest.mark.asyncio
+async def test_patch_operations_item(app_client, ctx):
+    operations = [
+        {"op": "add", "path": "/properties/hello", "value": "world"},
+        {"op": "remove", "path": "/properties/landsat:column"},
+        {"op": "replace", "path": "/properties/proj:epsg", "value": 1000},
+        {"op": "move", "path": "/properties/foo", "from": "/properties/instrument"},
+        {"op": "copy", "path": "/properties/bar", "from": "/properties/height"},
+    ]
+
+    resp = await app_client.patch(
+        f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}",
+        json=operations,
+        headers={"Content-Type": "application/json-patch+json"},
+    )
+
+    assert resp.status_code == 200
+
+    new_resp = await app_client.get(
+        f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}"
+    )
+
+    assert new_resp.status_code == 200
+
+    new_resp_json = new_resp.json()
+
+    assert new_resp_json["properties"]["hello"] == "world"
+    assert "landsat:column" not in new_resp_json["properties"]
+    assert "instrument" not in new_resp_json["properties"]
+    assert new_resp_json["properties"]["proj:epsg"] == 1000
+    assert new_resp_json["properties"]["foo"] == "OLI_TIRS"
+    assert new_resp_json["properties"]["bar"] == 2500
+    assert new_resp_json["properties"]["height"] == 2500
+    assert new_resp_json["properties"]["platform"] == "landsat-8"
+
+
 @pytest.mark.asyncio
 async def test_search_line_string_intersects(app_client, ctx):
     line = [[150.04, -33.14], [150.22, -33.89]]
diff --git a/stac_fastapi/tests/clients/test_es_os.py b/stac_fastapi/tests/clients/test_es_os.py
index 0f200826..5e909015 100644
--- a/stac_fastapi/tests/clients/test_es_os.py
+++ b/stac_fastapi/tests/clients/test_es_os.py
@@ -3,9 +3,11 @@
 from typing import Callable

 import pytest
+from fastapi import HTTPException
 from stac_pydantic import Item, api

 from stac_fastapi.types.errors import ConflictError, NotFoundError
+from stac_fastapi.types.stac import PatchAddReplaceTest, PatchMoveCopy, PatchRemove

 from ..conftest import MockRequest
@@ -236,6 +238,361 @@ async def test_update_item(ctx, core_client, txn_client):
     assert updated_item["properties"]["foo"] == "bar"

+@pytest.mark.asyncio
+async def test_merge_patch_item_add(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch={"properties": {"foo": "bar", "ext:hello": "world"}},
+        request=MockRequest,
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+    assert updated_item["properties"]["foo"] == "bar"
+    assert updated_item["properties"]["ext:hello"] == "world"
+
+
+@pytest.mark.asyncio
+async def test_merge_patch_item_remove(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch={"properties": {"gsd": None, "proj:epsg": None}},
+        request=MockRequest,
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+    assert "gsd" not in updated_item["properties"]
+    assert "proj:epsg" not in updated_item["properties"]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_add(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "add", "path": "/properties/foo", "value": "bar"}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "add", "path": "/properties/ext:hello", "value": "world"}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "add", "path": "/properties/area/1", "value": 10}
+        ),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert updated_item["properties"]["foo"] == "bar"
+    assert updated_item["properties"]["ext:hello"] == "world"
+    assert updated_item["properties"]["area"] == [2500, 10, -200]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_replace(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/properties/gsd", "value": 100}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/properties/proj:epsg", "value": 12345}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/properties/area/1", "value": 50}
+        ),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert updated_item["properties"]["gsd"] == 100
+    assert updated_item["properties"]["proj:epsg"] == 12345
+    assert updated_item["properties"]["area"] == [2500, 50]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_test(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/properties/gsd", "value": 15}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/properties/proj:epsg", "value": 32756}
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/properties/area/1", "value": -200}
+        ),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert updated_item["properties"]["gsd"] == 15
+    assert updated_item["properties"]["proj:epsg"] == 32756
+    assert updated_item["properties"]["area"][1] == -200
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_move(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/properties/foo", "from": "/properties/gsd"}
+        ),
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/properties/bar", "from": "/properties/proj:epsg"}
+        ),
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/properties/area/0", "from": "/properties/area/1"}
+        ),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert updated_item["properties"]["foo"] == 15
+    assert "gsd" not in updated_item["properties"]
+    assert updated_item["properties"]["bar"] == 32756
+    assert "proj:epsg" not in updated_item["properties"]
+    assert updated_item["properties"]["area"] == [-200, 2500]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_copy(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/properties/foo", "from": "/properties/gsd"}
+        ),
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/properties/bar", "from": "/properties/proj:epsg"}
+        ),
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/properties/area/0", "from": "/properties/area/1"}
+        ),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert updated_item["properties"]["foo"] == updated_item["properties"]["gsd"]
+    assert updated_item["properties"]["bar"] == updated_item["properties"]["proj:epsg"]
+    assert (
+        updated_item["properties"]["area"][0] == updated_item["properties"]["area"][1]
+    )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_remove(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchRemove.model_validate({"op": "remove", "path": "/properties/gsd"}),
+        PatchRemove.model_validate({"op": "remove", "path": "/properties/proj:epsg"}),
+        PatchRemove.model_validate({"op": "remove", "path": "/properties/area/1"}),
+    ]
+
+    await txn_client.patch_item(
+        collection_id=collection_id,
+        item_id=item_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_item = await core_client.get_item(
+        item_id, collection_id, request=MockRequest
+    )
+
+    assert "gsd" not in updated_item["properties"]
+    assert "proj:epsg" not in updated_item["properties"]
+    assert updated_item["properties"]["area"] == [2500]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_test_wrong_value(ctx, core_client, txn_client):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/properties/platform", "value": "landsat-9"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_replace_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/properties/foo", "value": "landsat-9"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_remove_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchRemove.model_validate({"op": "remove", "path": "/properties/foo"}),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_move_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/properties/bar", "from": "/properties/foo"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_item_copy_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    item = ctx.item
+    collection_id = item["collection"]
+    item_id = item["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/properties/bar", "from": "/properties/foo"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_item(
+            collection_id=collection_id,
+            item_id=item_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
 @pytest.mark.asyncio
 async def test_update_geometry(ctx, core_client, txn_client):
     new_coordinates = [
@@ -286,3 +643,291 @@ async def test_landing_page_no_collection_title(ctx, core_client, txn_client, ap
     for link in landing_page["links"]:
         if link["href"].split("/")[-1] == ctx.collection["id"]:
             assert link["title"]
+
+
+@pytest.mark.asyncio
+async def test_merge_patch_collection_add(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch={"summaries": {"foo": "bar", "hello": "world"}},
+        request=MockRequest,
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+    assert updated_collection["summaries"]["foo"] == "bar"
+    assert updated_collection["summaries"]["hello"] == "world"
+
+
+@pytest.mark.asyncio
+async def test_merge_patch_collection_remove(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch={"summaries": {"gsd": None}},
+        request=MockRequest,
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+    assert "gsd" not in updated_collection["summaries"]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_add(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "add", "path": "/summaries/foo", "value": "bar"},
+        ),
+        PatchAddReplaceTest.model_validate(
+            {"op": "add", "path": "/summaries/gsd/1", "value": 100},
+        ),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert updated_collection["summaries"]["foo"] == "bar"
+    assert updated_collection["summaries"]["gsd"] == [30, 100]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_replace(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/summaries/gsd", "value": [100]}
+        ),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert updated_collection["summaries"]["gsd"] == [100]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_test(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/summaries/gsd", "value": [30]}
+        ),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert updated_collection["summaries"]["gsd"] == [30]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_move(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/summaries/bar", "from": "/summaries/gsd"}
+        ),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert updated_collection["summaries"]["bar"] == [30]
+    assert "gsd" not in updated_collection["summaries"]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_copy(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/summaries/foo", "from": "/summaries/gsd"}
+        ),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert (
+        updated_collection["summaries"]["foo"] == updated_collection["summaries"]["gsd"]
+    )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_remove(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchRemove.model_validate({"op": "remove", "path": "/summaries/gsd"}),
+    ]
+
+    await txn_client.patch_collection(
+        collection_id=collection_id,
+        patch=operations,
+        request=MockRequest(headers={"content-type": "application/json-patch+json"}),
+    )
+
+    updated_collection = await core_client.get_collection(
+        collection_id, request=MockRequest
+    )
+
+    assert "gsd" not in updated_collection["summaries"]
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_test_wrong_value(ctx, core_client, txn_client):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "test", "path": "/summaries/platform", "value": "landsat-9"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_collection(
+            collection_id=collection_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_replace_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchAddReplaceTest.model_validate(
+            {"op": "replace", "path": "/summaries/foo", "value": "landsat-9"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_collection(
+            collection_id=collection_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_remove_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchRemove.model_validate({"op": "remove", "path": "/summaries/foo"}),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_collection(
+            collection_id=collection_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_move_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "move", "path": "/summaries/bar", "from": "/summaries/foo"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_collection(
+            collection_id=collection_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_json_patch_collection_copy_property_does_not_exists(
+    ctx, core_client, txn_client
+):
+    collection = ctx.collection
+    collection_id = collection["id"]
+    operations = [
+        PatchMoveCopy.model_validate(
+            {"op": "copy", "path": "/summaries/bar", "from": "/summaries/foo"}
+        ),
+    ]
+
+    with pytest.raises(HTTPException):
+        await txn_client.patch_collection(
+            collection_id=collection_id,
+            patch=operations,
+            request=MockRequest(
+                headers={"content-type": "application/json-patch+json"}
+            ),
+        )
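The failure-path tests above all expect `HTTPException`: a failed `test` op or a missing path trips a `Debug.explain` guard in the generated script, which each backend catches (`BadRequestError` / `RequestError`) and re-raises as a 400 carrying the Painless `caused_by` detail. Through the HTTP layer that contract looks like this (a sketch; host and IDs are placeholders):

```python
import httpx

resp = httpx.patch(
    "http://localhost:8080/collections/landsat-8/items/scene-1",
    json=[{"op": "test", "path": "/properties/platform", "value": "landsat-9"}],
    headers={"Content-Type": "application/json-patch+json"},
)
# The failed test aborts the script via Debug.explain, surfacing as 400.
assert resp.status_code == 400
```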
diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py
index 066b014d..5d172f40 100644
--- a/stac_fastapi/tests/conftest.py
+++ b/stac_fastapi/tests/conftest.py
@@ -71,6 +71,7 @@ def __init__(self, item, collection):
 class MockRequest:
     base_url = "http://test-server"
     url = "http://test-server/test"
+    headers = {"content-type": "application/json"}
     query_params = {}

     def __init__(
@@ -79,11 +80,13 @@ def __init__(
         url: str = "XXXX",
         app: Optional[Any] = None,
         query_params: Dict[str, Any] = {"limit": "10"},
+        headers: Dict[str, Any] = {"content-type": "application/json"},
     ):
         self.method = method
         self.url = url
         self.app = app
         self.query_params = query_params
+        self.headers = headers