Skip to content

Commit 07f89b7

Browse files
authored
feat: add store to search_to (#130)
1 parent 082a194 commit 07f89b7

File tree

4 files changed

+113
-92
lines changed

4 files changed

+113
-92
lines changed

Cargo.lock

Lines changed: 19 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/rustac/rustac.pyi

Lines changed: 56 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""The power of Rust for the Python STAC ecosystem."""
22

3+
from collections.abc import AsyncIterator
34
from pathlib import Path
4-
from typing import Any, AsyncIterator, Literal, Optional, Tuple
5+
from typing import Any, Literal
56

67
import arro3.core
78

@@ -45,18 +46,18 @@ class DuckdbClient:
4546
self,
4647
href: str,
4748
*,
48-
ids: Optional[str | list[str]] = None,
49-
collections: Optional[str | list[str]] = None,
50-
intersects: Optional[str | dict[str, Any]] = None,
51-
limit: Optional[int] = None,
52-
offset: Optional[int] = None,
53-
bbox: Optional[list[float]] = None,
54-
datetime: Optional[str] = None,
55-
include: Optional[str | list[str]] = None,
56-
exclude: Optional[str | list[str]] = None,
57-
sortby: Optional[str | list[str | dict[str, str]]] = None,
58-
filter: Optional[str | dict[str, Any]] = None,
59-
query: Optional[dict[str, Any]] = None,
49+
ids: str | list[str] | None = None,
50+
collections: str | list[str] | None = None,
51+
intersects: str | dict[str, Any] | None = None,
52+
limit: int | None = None,
53+
offset: int | None = None,
54+
bbox: list[float] | None = None,
55+
datetime: str | None = None,
56+
include: str | list[str] | None = None,
57+
exclude: str | list[str] | None = None,
58+
sortby: str | list[str | dict[str, str]] | None = None,
59+
filter: str | dict[str, Any] | None = None,
60+
query: dict[str, Any] | None = None,
6061
**kwargs: str,
6162
) -> list[dict[str, Any]]:
6263
"""Search a stac-geoparquet file with duckdb, returning a list of items.
@@ -94,18 +95,18 @@ class DuckdbClient:
9495
self,
9596
href: str,
9697
*,
97-
ids: Optional[str | list[str]] = None,
98-
collections: Optional[str | list[str]] = None,
99-
intersects: Optional[str | dict[str, Any]] = None,
100-
limit: Optional[int] = None,
101-
offset: Optional[int] = None,
102-
bbox: Optional[list[float]] = None,
103-
datetime: Optional[str] = None,
104-
include: Optional[str | list[str]] = None,
105-
exclude: Optional[str | list[str]] = None,
106-
sortby: Optional[str | list[str | dict[str, str]]] = None,
107-
filter: Optional[str | dict[str, Any]] = None,
108-
query: Optional[dict[str, Any]] = None,
98+
ids: str | list[str] | None = None,
99+
collections: str | list[str] | None = None,
100+
intersects: str | dict[str, Any] | None = None,
101+
limit: int | None = None,
102+
offset: int | None = None,
103+
bbox: list[float] | None = None,
104+
datetime: str | None = None,
105+
include: str | list[str] | None = None,
106+
exclude: str | list[str] | None = None,
107+
sortby: str | list[str | dict[str, str]] | None = None,
108+
filter: str | dict[str, Any] | None = None,
109+
query: dict[str, Any] | None = None,
109110
**kwargs: str,
110111
) -> arro3.core.Table | None:
111112
"""Search a stac-geoparquet file with duckdb, returning an arrow table
@@ -179,7 +180,7 @@ def collection_from_id_and_items(id: str, items: list[Item]) -> Collection:
179180
A STAC collection
180181
"""
181182

182-
def migrate(value: dict[str, Any], version: Optional[str] = None) -> dict[str, Any]:
183+
def migrate(value: dict[str, Any], version: str | None = None) -> dict[str, Any]:
183184
"""
184185
Migrates a STAC dictionary to another version.
185186
@@ -264,19 +265,19 @@ def to_arrow(
264265
async def search(
265266
href: str,
266267
*,
267-
intersects: Optional[str | dict[str, Any]] = None,
268-
ids: Optional[str | list[str]] = None,
269-
collections: Optional[str | list[str]] = None,
270-
max_items: Optional[int] = None,
271-
limit: Optional[int] = None,
272-
bbox: Optional[list[float]] = None,
273-
datetime: Optional[str] = None,
274-
include: Optional[str | list[str]] = None,
275-
exclude: Optional[str | list[str]] = None,
276-
sortby: Optional[str | list[str | dict[str, str]]] = None,
277-
filter: Optional[str | dict[str, Any]] = None,
278-
query: Optional[dict[str, Any]] = None,
279-
use_duckdb: Optional[bool] = None,
268+
intersects: str | dict[str, Any] | None = None,
269+
ids: str | list[str] | None = None,
270+
collections: str | list[str] | None = None,
271+
max_items: int | None = None,
272+
limit: int | None = None,
273+
bbox: list[float] | None = None,
274+
datetime: str | None = None,
275+
include: str | list[str] | None = None,
276+
exclude: str | list[str] | None = None,
277+
sortby: str | list[str | dict[str, str]] | None = None,
278+
filter: str | dict[str, Any] | None = None,
279+
query: dict[str, Any] | None = None,
280+
use_duckdb: bool | None = None,
280281
**kwargs: str,
281282
) -> list[dict[str, Any]]:
282283
"""
@@ -333,21 +334,21 @@ async def search_to(
333334
outfile: str,
334335
href: str,
335336
*,
336-
intersects: Optional[str | dict[str, Any]] = None,
337-
ids: Optional[str | list[str]] = None,
338-
collections: Optional[str | list[str]] = None,
339-
max_items: Optional[int] = None,
340-
limit: Optional[int] = None,
341-
bbox: Optional[list[float]] = None,
342-
datetime: Optional[str] = None,
343-
include: Optional[str | list[str]] = None,
344-
exclude: Optional[str | list[str]] = None,
345-
sortby: Optional[str | list[str | dict[str, str]]] = None,
346-
filter: Optional[str | dict[str, Any]] = None,
347-
query: Optional[dict[str, Any]] = None,
348-
format: Optional[str] = None,
349-
options: Optional[list[Tuple[str, str]]] = None,
350-
use_duckdb: Optional[bool] = None,
337+
intersects: str | dict[str, Any] | None = None,
338+
ids: str | list[str] | None = None,
339+
collections: str | list[str] | None = None,
340+
max_items: int | None = None,
341+
limit: int | None = None,
342+
bbox: list[float] | None = None,
343+
datetime: str | None = None,
344+
include: str | list[str] | None = None,
345+
exclude: str | list[str] | None = None,
346+
sortby: str | list[str | dict[str, str]] | None = None,
347+
filter: str | dict[str, Any] | None = None,
348+
query: dict[str, Any] | None = None,
349+
format: str | None = None,
350+
store: ObjectStore | None = None,
351+
use_duckdb: bool | None = None,
351352
) -> int:
352353
"""
353354
Searches a STAC API server and saves the result to an output file.
@@ -385,7 +386,7 @@ async def search_to(
385386
It is recommended to use filter instead, if possible.
386387
format: The output format. If none, will be inferred from
387388
the outfile extension, and if that fails will fall back to compact JSON.
388-
options: Configuration values to pass to the object store backend.
389+
store: An optional [ObjectStore][] to write outfile to. If None, outfile is written directly with no extra object store options.
389390
use_duckdb: Query with DuckDB. If None and the href has a
390391
'parquet' or 'geoparquet' extension, will be set to True. Defaults
391392
to None.

src/search.rs

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::{Error, Json, Result};
22
use geojson::Geometry;
33
use pyo3::prelude::*;
44
use pyo3::{Bound, FromPyObject, PyErr, PyResult, exceptions::PyValueError, types::PyDict};
5+
use pyo3_object_store::AnyObjectStore;
56
use stac::Bbox;
67
use stac::Format;
78
use stac_api::{Fields, Filter, Items, Search, Sortby};
@@ -57,7 +58,7 @@ pub fn search<'py>(
5758
}
5859

5960
#[pyfunction]
60-
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None, use_duckdb=None, **kwargs))]
61+
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, store=None, use_duckdb=None, **kwargs))]
6162
#[allow(clippy::too_many_arguments)]
6263
pub fn search_to<'py>(
6364
py: Python<'py>,
@@ -76,7 +77,7 @@ pub fn search_to<'py>(
7677
filter: Option<StringOrDict>,
7778
query: Option<Bound<'py, PyDict>>,
7879
format: Option<String>,
79-
options: Option<Vec<(String, String)>>,
80+
store: Option<AnyObjectStore>,
8081
use_duckdb: Option<bool>,
8182
kwargs: Option<Bound<'_, PyDict>>,
8283
) -> PyResult<Bound<'py, PyAny>> {
@@ -106,28 +107,36 @@ pub fn search_to<'py>(
106107
pyo3_async_runtimes::tokio::future_into_py(py, async move {
107108
let value = search_duckdb(href, search, max_items)?;
108109
let count = value.items.len();
109-
let _ = format
110-
.put_opts(
111-
outfile,
112-
serde_json::to_value(value).map_err(Error::from)?,
113-
options.unwrap_or_default(),
114-
)
115-
.await
116-
.map_err(Error::from)?;
110+
let value = serde_json::to_value(value).map_err(Error::from)?;
111+
if let Some(store) = store {
112+
format
113+
.put_store(store.into_dyn(), outfile, value)
114+
.await
115+
.map_err(Error::from)?;
116+
} else {
117+
format
118+
.put_opts(outfile, value, [] as [(&str, &str); 0])
119+
.await
120+
.map_err(Error::from)?;
121+
}
117122
Ok(count)
118123
})
119124
} else {
120125
pyo3_async_runtimes::tokio::future_into_py(py, async move {
121126
let value = search_api(href, search, max_items).await?;
122127
let count = value.items.len();
123-
let _ = format
124-
.put_opts(
125-
outfile,
126-
serde_json::to_value(value).map_err(Error::from)?,
127-
options.unwrap_or_default(),
128-
)
129-
.await
130-
.map_err(Error::from)?;
128+
let value = serde_json::to_value(value).map_err(Error::from)?;
129+
if let Some(store) = store {
130+
format
131+
.put_store(store.into_dyn(), outfile, value)
132+
.await
133+
.map_err(Error::from)?;
134+
} else {
135+
format
136+
.put_opts(outfile, value, [] as [(&str, &str); 0])
137+
.await
138+
.map_err(Error::from)?;
139+
}
131140
Ok(count)
132141
})
133142
}

tests/test_search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pyarrow.parquet
66
import rustac
77
import stac_geoparquet.arrow
8+
from rustac.store import MemoryStore
89

910

1011
async def test_search() -> None:
@@ -64,3 +65,13 @@ async def test_sortby_list_of_dict() -> None:
6465

6566
async def test_proj_geometry(maxar_items: list[dict[str, Any]], tmp_path: Path) -> None:
6667
await rustac.write(str(tmp_path / "out.parquet"), maxar_items)
68+
69+
70+
async def test_search_to_store(data: Path) -> None:
71+
store = MemoryStore()
72+
count = await rustac.search_to(
73+
"items.json", str(data / "100-sentinel-2-items.parquet"), store=store
74+
)
75+
assert count == 100
76+
item_collection = await rustac.read("items.json", store=store)
77+
assert len(item_collection["features"]) == 100

0 commit comments

Comments
 (0)