Skip to content

Commit b418a57

Browse files
authored
Coerce cql2 style to match HTTP method (#804)
Use cql2 library to convert between cql2-text and cql2-json
1 parent b81cbee commit b418a57

File tree

5 files changed

+201
-15
lines changed

5 files changed

+201
-15
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- Coerce cql2 style to match HTTP method using `cql2` library ([#804](https://github.com/stac-utils/pystac-client/pull/804))
13+
1014
### Fixed
1115

12-
- Fix usage documentation of `ItemSearch`
16+
- Fix usage documentation of `ItemSearch` ([#790](https://github.com/stac-utils/pystac-client/pull/790))
1317
- Fix fields argument to CLI ([#797](https://github.com/stac-utils/pystac-client/pull/797))
1418
- Clarify recursive behaviour of the `get_items` method in the method docstring ([#800](https://github.com/stac-utils/pystac-client/pull/800))
1519

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ stac-client = "pystac_client.cli:cli"
3939
dev = [
4040
"codespell~=2.4.0",
4141
"coverage~=7.2",
42+
"cql2>=0.3.7",
4243
"doc8~=1.1.1",
4344
"importlib-metadata~=8.0",
4445
"mypy~=1.2",
@@ -55,7 +56,7 @@ dev = [
5556
"tomli~=2.0; python_version<'3.11'",
5657
"types-python-dateutil>=2.8.19,<2.10.0",
5758
"types-requests~=2.32.0",
58-
"urllib3>=2.0,<2.3.0", # v2.3.0 breaks VCR, b/c https://github.com/urllib3/urllib3/pull/3489
59+
"urllib3>=2.0,<2.3.0", # v2.3.0 breaks VCR, b/c https://github.com/urllib3/urllib3/pull/3489
5960
]
6061
docs = [
6162
"Sphinx~=8.0",

pystac_client/item_search.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(
159159

160160
self.method = method
161161
self.modifier = modifier
162+
162163
params = {
163164
"limit": limit,
164165
"bbox": self._format_bbox(bbox),
@@ -167,8 +168,8 @@ def __init__(
167168
"collections": self._format_collections(collections),
168169
"intersects": self._format_intersects(intersects),
169170
"query": self._format_query(query),
170-
"filter": self._format_filter(filter),
171-
"filter-lang": self._format_filter_lang(filter, filter_lang),
171+
"filter": self._format_filter(method, filter_lang, filter),
172+
"filter-lang": self._format_filter_lang(method, filter, filter_lang),
172173
"sortby": self._format_sortby(sortby),
173174
"fields": self._format_fields(fields),
174175
"q": q,
@@ -204,6 +205,8 @@ def _clean_params_for_get_request(self) -> dict[str, Any]:
204205
params["sortby"] = self._sortby_dict_to_str(params["sortby"])
205206
if "fields" in params:
206207
params["fields"] = self._fields_dict_to_str(params["fields"])
208+
if "filter" in params and isinstance(params["filter"], dict):
209+
params["filter"] = json.dumps(params["filter"])
207210
return params
208211

209212
def url_with_parameters(self) -> str:
@@ -266,29 +269,69 @@ def _format_query(self, value: QueryLike | None) -> dict[str, Any] | None:
266269

267270
@staticmethod
268271
def _format_filter_lang(
269-
_filter: FilterLike | None, value: FilterLangLike | None
272+
method: str | None,
273+
_filter: FilterLike | None,
274+
value: FilterLangLike | None,
270275
) -> str | None:
271276
if _filter is None:
272277
return None
273278

274279
if value is not None:
275280
return value
276281

277-
if isinstance(_filter, str):
282+
if method == "GET":
278283
return "cql2-text"
279284

280-
if isinstance(_filter, dict):
285+
if method == "POST":
281286
return "cql2-json"
282287

283288
return None
284289

285-
def _format_filter(self, value: FilterLike | None) -> FilterLike | None:
286-
if value is None:
290+
def _format_filter(
291+
self,
292+
method: str | None,
293+
filter_lang: FilterLangLike | None,
294+
value: FilterLike | None,
295+
) -> FilterLike | None:
296+
if not value:
287297
return None
288298

289299
if self.client and not self.client.conforms_to(ConformanceClasses.FILTER):
290300
warnings.warn(DoesNotConformTo("FILTER"))
291301

302+
if method == "GET" and isinstance(value, str):
303+
return value
304+
305+
if method == "POST" and isinstance(value, dict):
306+
return value
307+
308+
# if filter_lang is specified, do not coerce
309+
if filter_lang is not None:
310+
return value
311+
312+
try:
313+
import cql2
314+
315+
if isinstance(value, dict):
316+
expr = cql2.parse_json(json.dumps(value))
317+
else:
318+
# could be cql2-text or stringified cql2-json
319+
expr = cql2.Expr(value)
320+
321+
except ImportError as e:
322+
raise ValueError(
323+
"Unless you specify ``filter_lang`` pystac-client will try to convert "
324+
"the filter to cql2-text or cql2-json based on the HTTP method "
325+
"provided.\n"
326+
"Resolve this error by installing ``cql2``: ``pip install cql2``"
327+
) from e
328+
329+
if method == "GET":
330+
return str(expr.to_text())
331+
332+
if method == "POST":
333+
return dict(expr.to_json())
334+
292335
return value
293336

294337
@staticmethod

tests/test_base_search.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,14 @@ def test_intersects_non_geo_interface_object(self) -> None:
330330
with pytest.raises(Exception):
331331
BaseSearch(url=SEARCH_URL, intersects=object()) # type: ignore
332332

333-
def test_filter_lang_default_for_dict(self) -> None:
334-
search = BaseSearch(url=SEARCH_URL, filter={})
335-
assert search.get_parameters()["filter-lang"] == "cql2-json"
336-
337-
def test_filter_lang_default_for_str(self) -> None:
338-
search = BaseSearch(url=SEARCH_URL, filter="")
333+
def test_filter_lang_default_for_method_despite_filter_as_dict(self) -> None:
334+
search = BaseSearch(url=SEARCH_URL, method="GET", filter={})
339335
assert search.get_parameters()["filter-lang"] == "cql2-text"
340336

337+
def test_filter_lang_default_for_method_despite_filter_as_str(self) -> None:
338+
search = BaseSearch(url=SEARCH_URL, method="POST", filter="")
339+
assert search.get_parameters()["filter-lang"] == "cql2-json"
340+
341341
def test_filter_lang_cql2_text(self) -> None:
342342
# Use specified filter_lang
343343
search = BaseSearch(url=SEARCH_URL, filter_lang="cql2-text", filter={})
@@ -353,6 +353,51 @@ def test_filter_lang_without_filter(self) -> None:
353353
search = BaseSearch(url=SEARCH_URL)
354354
assert "filter-lang" not in search.get_parameters()
355355

356+
def test_filter_conversion_to_cql2_json(self) -> None:
357+
search = BaseSearch(url=SEARCH_URL, method="POST", filter="eo:cloud_cover<=10")
358+
assert search.get_parameters()["filter-lang"] == "cql2-json"
359+
assert search.get_parameters()["filter"] == {
360+
"args": [{"property": "eo:cloud_cover"}, 10],
361+
"op": "<=",
362+
}
363+
364+
def test_filter_conversion_to_cql2_text(self) -> None:
365+
search = BaseSearch(
366+
url=SEARCH_URL,
367+
method="GET",
368+
filter={"op": "<=", "args": [{"property": "eo:cloud_cover"}, 10]},
369+
)
370+
assert search.get_parameters()["filter-lang"] == "cql2-text"
371+
assert search.get_parameters()["filter"] == '("eo:cloud_cover" <= 10)'
372+
373+
def test_filter_conversion_does_not_happen_if_filter_lang_specified_json(
374+
self,
375+
) -> None:
376+
search = BaseSearch(
377+
url=SEARCH_URL,
378+
method="GET",
379+
filter={"op": "<=", "args": [{"property": "eo:cloud_cover"}, 10]},
380+
filter_lang="cql2-json",
381+
)
382+
# assert search.get_parameters()["filter-lang"] == "cql2-json"
383+
assert (
384+
search.get_parameters()["filter"]
385+
== '{"op": "<=", "args": [{"property": "eo:cloud_cover"}, 10]}'
386+
)
387+
388+
def test_filter_conversion_does_not_happen_if_filter_lang_specified_text(
389+
self,
390+
) -> None:
391+
search = BaseSearch(
392+
url=SEARCH_URL,
393+
method="POST",
394+
filter="eo:cloud_cover<=10",
395+
filter_lang="cql2-text",
396+
)
397+
# note that this is likely to fail when it hits the server
398+
assert search.get_parameters()["filter-lang"] == "cql2-text"
399+
assert search.get_parameters()["filter"] == "eo:cloud_cover<=10"
400+
356401
def test_sortby(self) -> None:
357402
search = BaseSearch(url=SEARCH_URL, sortby="properties.datetime")
358403
assert search.get_parameters()["sortby"] == [

0 commit comments

Comments
 (0)