Skip to content

Commit e32d0fc

Browse files
improve query validation (#49)
* better query validation * bugfix * add test
1 parent fce40e5 commit e32d0fc

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

stac_api/models/schemas.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from shapely.geometry import shape
1313

1414
from geojson_pydantic.geometries import Polygon
15-
from pydantic import Field, root_validator
15+
from pydantic import Field, ValidationError, root_validator
16+
from pydantic.error_wrappers import ErrorWrapper
1617
from stac_api import config
1718
from stac_api.models.decompose import CollectionGetter, ItemGetter
1819
from stac_pydantic import Collection as CollectionBase
@@ -177,12 +178,30 @@ class STACSearch(Search):
177178
query: Optional[Dict[Queryables, Dict[Operator, Any]]]
178179
token: Optional[str] = None
179180

181+
@root_validator(pre=True)
182+
def validate_query_fields(cls, values: Dict) -> Dict:
183+
"""validate query fields"""
184+
if "query" in values and values["query"]:
185+
queryable_fields = Queryables.__members__.values()
186+
for field_name in values["query"]:
187+
if field_name not in queryable_fields:
188+
raise ValidationError(
189+
[
190+
ErrorWrapper(
191+
ValueError(f"Cannot search on field: {field_name}"),
192+
"STACSearch",
193+
)
194+
],
195+
STACSearch,
196+
)
197+
return values
198+
180199
@root_validator
181200
def include_query_fields(cls, values: Dict) -> Dict:
182201
"""
183202
Root validator to ensure query fields are included in the API response
184203
"""
185-
if values["query"]:
204+
if "query" in values and values["query"]:
186205
query_include = set(
187206
[
188207
k.value

tests/resources/test_item.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,9 @@ def test_get_missing_item(app_client, load_test_data):
696696
test_coll = load_test_data("test_collection.json")
697697
resp = app_client.get(f"/collections/{test_coll['id']}/items/invalid-item")
698698
assert resp.status_code == 404
699+
700+
701+
def test_search_invalid_query_field(app_client):
702+
body = {"query": {"gsd": {"lt": 100}, "invalid-field": {"eq": 50}}}
703+
resp = app_client.post("/search", json=body)
704+
assert resp.status_code == 422

0 commit comments

Comments
 (0)