Skip to content

Custom decoder registry #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/python/async_tiff/_decoder.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Protocol
from collections.abc import Buffer

from .enums import CompressionMethod

class Decoder(Protocol):
@staticmethod
def __call__(buffer: Buffer) -> Buffer: ...

class DecoderRegistry:
def __init__(self) -> None: ...
def add(self, compression: CompressionMethod | int, decoder: Decoder) -> None: ...
63 changes: 63 additions & 0 deletions python/src/decoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use async_tiff::decoder::{Decoder, DecoderRegistry};
use async_tiff::error::AiocogeoError;
use bytes::Bytes;
use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use pyo3_bytes::PyBytes;

use crate::enums::PyCompressionMethod;

#[pyclass(name = "DecoderRegistry")]
pub(crate) struct PyDecoderRegistry(DecoderRegistry);

#[pymethods]
impl PyDecoderRegistry {
#[new]
fn new() -> Self {
Self(DecoderRegistry::default())
}

fn add(&mut self, compression: PyCompressionMethod, decoder: PyDecoder) {
self.0
.as_mut()
.insert(compression.into(), Box::new(decoder));
}
}

#[derive(Debug)]
struct PyDecoder(PyObject);

impl PyDecoder {
fn call(&self, py: Python, buffer: Bytes) -> PyResult<PyBytes> {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "buffer"), PyBytes::new(buffer))?;
let result = self.0.call(py, PyTuple::empty(py), Some(&kwargs))?;
result.extract(py)
}
}

impl<'py> FromPyObject<'py> for PyDecoder {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
if !ob.hasattr(intern!(ob.py(), "__call__"))? {
return Err(PyTypeError::new_err(
"Expected callable object for custom decoder.",
));
}
Ok(Self(ob.clone().unbind()))
}
}

impl Decoder for PyDecoder {
fn decode_tile(
&self,
buffer: bytes::Bytes,
_photometric_interpretation: tiff::tags::PhotometricInterpretation,
_jpeg_tables: Option<&[u8]>,
) -> async_tiff::error::Result<bytes::Bytes> {
let decoded_buffer = Python::with_gil(|py| self.call(py, buffer))
.map_err(|err| AiocogeoError::General(err.to_string()))?;
Ok(decoded_buffer.into_inner())
}
}
12 changes: 12 additions & 0 deletions python/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ impl From<CompressionMethod> for PyCompressionMethod {
}
}

impl From<PyCompressionMethod> for CompressionMethod {
fn from(value: PyCompressionMethod) -> Self {
value.0
}
}

impl<'py> FromPyObject<'py> for PyCompressionMethod {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
Ok(Self(CompressionMethod::from_u16_exhaustive(ob.extract()?)))
}
}

impl<'py> IntoPyObject<'py> for PyCompressionMethod {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
Expand Down
3 changes: 3 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#![deny(clippy::undocumented_unsafe_blocks)]

mod decoder;
mod enums;
mod geo;
mod ifd;
mod tiff;

use pyo3::prelude::*;

use crate::decoder::PyDecoderRegistry;
use crate::geo::PyGeoKeyDirectory;
use crate::ifd::PyImageFileDirectory;
use crate::tiff::PyTIFF;
Expand Down Expand Up @@ -43,6 +45,7 @@ fn _async_tiff(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
check_debug_build(py)?;

m.add_wrapped(wrap_pyfunction!(___version))?;
m.add_class::<PyDecoderRegistry>()?;
m.add_class::<PyGeoKeyDirectory>()?;
m.add_class::<PyImageFileDirectory>()?;
m.add_class::<PyTIFF>()?;
Expand Down
7 changes: 6 additions & 1 deletion src/cog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod test {
use std::io::BufReader;
use std::sync::Arc;

use crate::decoder::DecoderRegistry;
use crate::ObjectReader;

use super::*;
Expand All @@ -66,7 +67,11 @@ mod test {
let cog_reader = COGReader::try_open(Box::new(reader.clone())).await.unwrap();

let ifd = &cog_reader.ifds.as_ref()[1];
let tile = ifd.get_tile(0, 0, Box::new(reader)).await.unwrap();
let decoder_registry = DecoderRegistry::default();
let tile = ifd
.get_tile(0, 0, Box::new(reader), &decoder_registry)
.await
.unwrap();
std::fs::write("img.buf", tile).unwrap();
}

Expand Down
157 changes: 123 additions & 34 deletions src/decoder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;
use std::fmt::Debug;
use std::io::{Cursor, Read};

use bytes::Bytes;
Expand All @@ -7,47 +9,138 @@ use tiff::{TiffError, TiffUnsupportedError};

use crate::error::Result;

/// A registry of decoders.
#[derive(Debug)]
pub struct DecoderRegistry(HashMap<CompressionMethod, Box<dyn Decoder>>);

impl DecoderRegistry {
/// Create a new decoder registry with no decoders registered
pub fn new() -> Self {
Self(HashMap::new())
}
}

impl AsRef<HashMap<CompressionMethod, Box<dyn Decoder>>> for DecoderRegistry {
fn as_ref(&self) -> &HashMap<CompressionMethod, Box<dyn Decoder>> {
&self.0
}
}

impl AsMut<HashMap<CompressionMethod, Box<dyn Decoder>>> for DecoderRegistry {
fn as_mut(&mut self) -> &mut HashMap<CompressionMethod, Box<dyn Decoder>> {
&mut self.0
}
}

impl Default for DecoderRegistry {
fn default() -> Self {
let mut registry = HashMap::with_capacity(5);
registry.insert(CompressionMethod::None, Box::new(UncompressedDecoder) as _);
registry.insert(CompressionMethod::Deflate, Box::new(DeflateDecoder) as _);
registry.insert(CompressionMethod::OldDeflate, Box::new(DeflateDecoder) as _);
registry.insert(CompressionMethod::LZW, Box::new(LZWDecoder) as _);
registry.insert(CompressionMethod::ModernJPEG, Box::new(JPEGDecoder) as _);
Self(registry)
}
}

/// A trait to decode a TIFF tile.
pub trait Decoder: Debug + Send + Sync {
fn decode_tile(
&self,
buffer: Bytes,
photometric_interpretation: PhotometricInterpretation,
jpeg_tables: Option<&[u8]>,
) -> Result<Bytes>;
}

#[derive(Debug, Clone)]
pub struct DeflateDecoder;

impl Decoder for DeflateDecoder {
fn decode_tile(
&self,
buffer: Bytes,
_photometric_interpretation: PhotometricInterpretation,
_jpeg_tables: Option<&[u8]>,
) -> Result<Bytes> {
let mut decoder = ZlibDecoder::new(Cursor::new(buffer));
let mut buf = Vec::new();
decoder.read_to_end(&mut buf)?;
Ok(buf.into())
}
}

#[derive(Debug, Clone)]
pub struct JPEGDecoder;

impl Decoder for JPEGDecoder {
fn decode_tile(
&self,
buffer: Bytes,
photometric_interpretation: PhotometricInterpretation,
jpeg_tables: Option<&[u8]>,
) -> Result<Bytes> {
decode_modern_jpeg(buffer, photometric_interpretation, jpeg_tables)
}
}

#[derive(Debug, Clone)]
pub struct LZWDecoder;

impl Decoder for LZWDecoder {
fn decode_tile(
&self,
buffer: Bytes,
_photometric_interpretation: PhotometricInterpretation,
_jpeg_tables: Option<&[u8]>,
) -> Result<Bytes> {
// https://github.com/image-rs/image-tiff/blob/90ae5b8e54356a35e266fb24e969aafbcb26e990/src/decoder/stream.rs#L147
let mut decoder = weezl::decode::Decoder::with_tiff_size_switch(weezl::BitOrder::Msb, 8);
let decoded = decoder.decode(&buffer).expect("failed to decode LZW data");
Ok(decoded.into())
}
}

#[derive(Debug, Clone)]
pub struct UncompressedDecoder;

impl Decoder for UncompressedDecoder {
fn decode_tile(
&self,
buffer: Bytes,
_photometric_interpretation: PhotometricInterpretation,
_jpeg_tables: Option<&[u8]>,
) -> Result<Bytes> {
Ok(buffer)
}
}

// https://github.com/image-rs/image-tiff/blob/3bfb43e83e31b0da476832067ada68a82b378b7b/src/decoder/image.rs#L370
pub(crate) fn decode_tile(
buf: Bytes,
photometric_interpretation: PhotometricInterpretation,
compression_method: CompressionMethod,
// compressed_length: u64,
jpeg_tables: Option<&Vec<u8>>,
jpeg_tables: Option<&[u8]>,
decoder_registry: &DecoderRegistry,
) -> Result<Bytes> {
match compression_method {
CompressionMethod::None => Ok(buf),
CompressionMethod::LZW => decode_lzw(buf),
CompressionMethod::Deflate | CompressionMethod::OldDeflate => decode_deflate(buf),
CompressionMethod::ModernJPEG => {
decode_modern_jpeg(buf, photometric_interpretation, jpeg_tables)
}
method => Err(TiffError::UnsupportedError(
TiffUnsupportedError::UnsupportedCompressionMethod(method),
)
.into()),
}
}

fn decode_lzw(buf: Bytes) -> Result<Bytes> {
// https://github.com/image-rs/image-tiff/blob/90ae5b8e54356a35e266fb24e969aafbcb26e990/src/decoder/stream.rs#L147
let mut decoder = weezl::decode::Decoder::with_tiff_size_switch(weezl::BitOrder::Msb, 8);
let decoded = decoder.decode(&buf).expect("failed to decode LZW data");
Ok(decoded.into())
}
let decoder =
decoder_registry
.0
.get(&compression_method)
.ok_or(TiffError::UnsupportedError(
TiffUnsupportedError::UnsupportedCompressionMethod(compression_method),
))?;

fn decode_deflate(buf: Bytes) -> Result<Bytes> {
let mut decoder = ZlibDecoder::new(Cursor::new(buf));
let mut buf = Vec::new();
decoder.read_to_end(&mut buf)?;
Ok(buf.into())
decoder.decode_tile(buf, photometric_interpretation, jpeg_tables)
}

// https://github.com/image-rs/image-tiff/blob/3bfb43e83e31b0da476832067ada68a82b378b7b/src/decoder/image.rs#L389-L450
fn decode_modern_jpeg(
buf: Bytes,
photometric_interpretation: PhotometricInterpretation,
jpeg_tables: Option<&Vec<u8>>,
jpeg_tables: Option<&[u8]>,
) -> Result<Bytes> {
// Construct new jpeg_reader wrapping a SmartReader.
//
Expand Down Expand Up @@ -76,13 +169,9 @@ fn decode_modern_jpeg(

match photometric_interpretation {
PhotometricInterpretation::RGB => decoder.set_color_transform(jpeg::ColorTransform::RGB),
PhotometricInterpretation::WhiteIsZero => {
decoder.set_color_transform(jpeg::ColorTransform::None)
}
PhotometricInterpretation::BlackIsZero => {
decoder.set_color_transform(jpeg::ColorTransform::None)
}
PhotometricInterpretation::TransparencyMask => {
PhotometricInterpretation::WhiteIsZero
| PhotometricInterpretation::BlackIsZero
| PhotometricInterpretation::TransparencyMask => {
decoder.set_color_transform(jpeg::ColorTransform::None)
}
PhotometricInterpretation::CMYK => decoder.set_color_transform(jpeg::ColorTransform::CMYK),
Expand Down
10 changes: 7 additions & 3 deletions src/ifd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tiff::tags::{
use tiff::TiffError;

use crate::async_reader::AsyncCursor;
use crate::decoder::decode_tile;
use crate::decoder::{decode_tile, DecoderRegistry};
use crate::error::{AiocogeoError, Result};
use crate::geo::{AffineTransform, GeoKeyDirectory, GeoKeyTag};
use crate::AsyncFileReader;
Expand Down Expand Up @@ -681,14 +681,16 @@ impl ImageFileDirectory {
x: usize,
y: usize,
mut reader: Box<dyn AsyncFileReader>,
decoder_registry: &DecoderRegistry,
) -> Result<Bytes> {
let range = self.get_tile_byte_range(x, y);
let buf = reader.get_bytes(range).await?;
decode_tile(
buf,
self.photometric_interpretation,
self.compression,
self.jpeg_tables.as_ref(),
self.jpeg_tables.as_deref(),
decoder_registry,
)
}

Expand All @@ -697,6 +699,7 @@ impl ImageFileDirectory {
x: &[usize],
y: &[usize],
mut reader: Box<dyn AsyncFileReader>,
decoder_registry: &DecoderRegistry,
) -> Result<Vec<Bytes>> {
assert_eq!(x.len(), y.len(), "x and y should have same len");

Expand All @@ -717,7 +720,8 @@ impl ImageFileDirectory {
buf,
self.photometric_interpretation,
self.compression,
self.jpeg_tables.as_ref(),
self.jpeg_tables.as_deref(),
decoder_registry,
)?;
decoded_tiles.push(decoded);
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

mod async_reader;
mod cog;
mod decoder;
pub mod decoder;
pub mod error;
pub mod geo;
mod ifd;
Expand Down