Skip to content

Commit ed67b77

Browse files
authored
Decode tile off the main thread (#40)
* Decode tile off the main thread * remove builder * pubcrate visibility
1 parent 937ec7b commit ed67b77

File tree

15 files changed

+178
-98
lines changed

15 files changed

+178
-98
lines changed

python/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ pyo3 = { version = "0.23.0", features = ["macros"] }
2525
pyo3-async-runtimes = "0.23"
2626
pyo3-bytes = "0.1.2"
2727
pyo3-object_store = { git = "https://github.com/developmentseed/obstore", rev = "28ba07a621c1c104f084fb47ae7f8d08b1eae3ea" }
28+
rayon = "1.10.0"
29+
tokio-rayon = "2.1.0"
2830
thiserror = "1"
2931

3032
# We opt-in to using rustls as the TLS provider for reqwest, which is the HTTP
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from ._decoder import Decoder as Decoder
2+
from ._decoder import DecoderRegistry as DecoderRegistry
13
from ._geo import GeoKeyDirectory as GeoKeyDirectory
24
from ._ifd import ImageFileDirectory as ImageFileDirectory
5+
from ._thread_pool import ThreadPool as ThreadPool
36
from ._tiff import TIFF as TIFF
7+
from ._tile import Tile as Tile

python/python/async_tiff/_decoder.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ class Decoder(Protocol):
1010
def __call__(buffer: Buffer) -> Buffer: ...
1111

1212
class DecoderRegistry:
13-
def __init__(self) -> None: ...
14-
def add(self, compression: CompressionMethod | int, decoder: Decoder) -> None: ...
13+
def __init__(
14+
self, decoders: dict[CompressionMethod | int, Decoder] | None = None
15+
) -> None: ...
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class ThreadPool:
2+
def __init__(self, num_threads: int) -> None: ...

python/python/async_tiff/_tile.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from collections.abc import Buffer
2+
3+
from ._decoder import DecoderRegistry
4+
from ._thread_pool import ThreadPool
5+
6+
class Tile:
7+
async def decode(
8+
self,
9+
*,
10+
decoder_registry: DecoderRegistry | None = None,
11+
pool: ThreadPool | None = None,
12+
) -> Buffer: ...

python/src/decoder.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,55 @@
1+
use std::collections::HashMap;
2+
use std::sync::Arc;
3+
14
use async_tiff::decoder::{Decoder, DecoderRegistry};
25
use async_tiff::error::AiocogeoError;
36
use async_tiff::tiff::tags::PhotometricInterpretation;
47
use bytes::Bytes;
58
use pyo3::exceptions::PyTypeError;
69
use pyo3::intern;
710
use pyo3::prelude::*;
11+
use pyo3::sync::GILOnceCell;
812
use pyo3::types::{PyDict, PyTuple};
913
use pyo3_bytes::PyBytes;
1014

1115
use crate::enums::PyCompressionMethod;
1216

13-
#[pyclass(name = "DecoderRegistry")]
14-
pub(crate) struct PyDecoderRegistry(DecoderRegistry);
17+
static DEFAULT_DECODER_REGISTRY: GILOnceCell<Arc<DecoderRegistry>> = GILOnceCell::new();
18+
19+
pub fn get_default_decoder_registry(py: Python<'_>) -> Arc<DecoderRegistry> {
20+
let registry =
21+
DEFAULT_DECODER_REGISTRY.get_or_init(py, || Arc::new(DecoderRegistry::default()));
22+
registry.clone()
23+
}
24+
25+
#[pyclass(name = "DecoderRegistry", frozen)]
26+
#[derive(Debug, Default)]
27+
pub(crate) struct PyDecoderRegistry(Arc<DecoderRegistry>);
1528

1629
#[pymethods]
1730
impl PyDecoderRegistry {
1831
#[new]
19-
fn new() -> Self {
20-
Self(DecoderRegistry::default())
32+
#[pyo3(signature = (decoders = None))]
33+
pub(crate) fn new(decoders: Option<HashMap<PyCompressionMethod, PyDecoder>>) -> Self {
34+
let mut decoder_registry = DecoderRegistry::default();
35+
if let Some(decoders) = decoders {
36+
for (compression, decoder) in decoders.into_iter() {
37+
decoder_registry
38+
.as_mut()
39+
.insert(compression.into(), Box::new(decoder));
40+
}
41+
}
42+
Self(Arc::new(decoder_registry))
2143
}
22-
23-
fn add(&mut self, compression: PyCompressionMethod, decoder: PyDecoder) {
24-
self.0
25-
.as_mut()
26-
.insert(compression.into(), Box::new(decoder));
44+
}
45+
impl PyDecoderRegistry {
46+
pub(crate) fn inner(&self) -> &Arc<DecoderRegistry> {
47+
&self.0
2748
}
2849
}
2950

3051
#[derive(Debug)]
31-
struct PyDecoder(PyObject);
52+
pub(crate) struct PyDecoder(PyObject);
3253

3354
impl PyDecoder {
3455
fn call(&self, py: Python, buffer: Bytes) -> PyResult<PyBytes> {

python/src/enums.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use pyo3::intern;
66
use pyo3::prelude::*;
77
use pyo3::types::{PyString, PyTuple};
88

9+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
910
pub(crate) struct PyCompressionMethod(CompressionMethod);
1011

1112
impl From<CompressionMethod> for PyCompressionMethod {

python/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ mod decoder;
44
mod enums;
55
mod geo;
66
mod ifd;
7+
mod thread_pool;
78
mod tiff;
9+
mod tile;
810

911
use pyo3::prelude::*;
1012

1113
use crate::decoder::PyDecoderRegistry;
1214
use crate::geo::PyGeoKeyDirectory;
1315
use crate::ifd::PyImageFileDirectory;
16+
use crate::thread_pool::PyThreadPool;
1417
use crate::tiff::PyTIFF;
1518

1619
const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -48,6 +51,7 @@ fn _async_tiff(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
4851
m.add_class::<PyDecoderRegistry>()?;
4952
m.add_class::<PyGeoKeyDirectory>()?;
5053
m.add_class::<PyImageFileDirectory>()?;
54+
m.add_class::<PyThreadPool>()?;
5155
m.add_class::<PyTIFF>()?;
5256

5357
pyo3_object_store::register_store_module(py, m, "async_tiff")?;

python/src/thread_pool.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use std::sync::Arc;
2+
3+
use pyo3::exceptions::PyValueError;
4+
use pyo3::prelude::*;
5+
6+
use pyo3::sync::GILOnceCell;
7+
use rayon::{ThreadPool, ThreadPoolBuilder};
8+
9+
static DEFAULT_POOL: GILOnceCell<Arc<ThreadPool>> = GILOnceCell::new();
10+
11+
pub fn get_default_pool(py: Python<'_>) -> PyResult<Arc<ThreadPool>> {
12+
let runtime = DEFAULT_POOL.get_or_try_init(py, || {
13+
let pool = ThreadPoolBuilder::new().build().map_err(|err| {
14+
PyValueError::new_err(format!("Could not create rayon threadpool. {}", err))
15+
})?;
16+
Ok::<_, PyErr>(Arc::new(pool))
17+
})?;
18+
Ok(runtime.clone())
19+
}
20+
21+
#[pyclass(name = "ThreadPool", frozen, module = "async_tiff")]
22+
pub(crate) struct PyThreadPool(Arc<ThreadPool>);
23+
24+
#[pymethods]
25+
impl PyThreadPool {
26+
#[new]
27+
fn new(num_threads: usize) -> PyResult<Self> {
28+
let pool = ThreadPoolBuilder::new()
29+
.num_threads(num_threads)
30+
.build()
31+
.map_err(|err| {
32+
PyValueError::new_err(format!("Could not create rayon threadpool. {}", err))
33+
})?;
34+
Ok(Self(Arc::new(pool)))
35+
}
36+
}
37+
38+
impl PyThreadPool {
39+
pub(crate) fn inner(&self) -> &Arc<ThreadPool> {
40+
&self.0
41+
}
42+
}
43+
44+
impl AsRef<ThreadPool> for PyThreadPool {
45+
fn as_ref(&self) -> &ThreadPool {
46+
&self.0
47+
}
48+
}

python/src/tile.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use async_tiff::Tile;
2+
use pyo3::prelude::*;
3+
use pyo3_async_runtimes::tokio::future_into_py;
4+
use pyo3_bytes::PyBytes;
5+
use tokio_rayon::AsyncThreadPool;
6+
7+
use crate::decoder::get_default_decoder_registry;
8+
use crate::thread_pool::{get_default_pool, PyThreadPool};
9+
use crate::PyDecoderRegistry;
10+
11+
#[pyclass(name = "Tile")]
12+
pub(crate) struct PyTile(Option<Tile>);
13+
14+
#[pymethods]
15+
impl PyTile {
16+
#[pyo3(signature = (*, decoder_registry=None, pool=None))]
17+
fn decode_async(
18+
&mut self,
19+
py: Python,
20+
decoder_registry: Option<&PyDecoderRegistry>,
21+
pool: Option<&PyThreadPool>,
22+
) -> PyResult<PyObject> {
23+
let decoder_registry = decoder_registry
24+
.map(|r| r.inner().clone())
25+
.unwrap_or_else(|| get_default_decoder_registry(py));
26+
let pool = pool
27+
.map(|p| Ok(p.inner().clone()))
28+
.unwrap_or_else(|| get_default_pool(py))?;
29+
let tile = self.0.take().unwrap();
30+
31+
let result = future_into_py(py, async move {
32+
let decoded_bytes = pool
33+
.spawn_async(move || tile.decode(&decoder_registry))
34+
.await
35+
.unwrap();
36+
Ok(PyBytes::new(decoded_bytes))
37+
})?;
38+
Ok(result.unbind())
39+
}
40+
}

0 commit comments

Comments
 (0)