diff --git a/pyproject.toml b/pyproject.toml
index 669ea94..c316478 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,9 @@ schema = [
   "fastjsonschema",
   "importlib-resources; python_version<'3.9'",
 ]
+hdf5 = [
+  "h5py",
+]
 
 [dependency-groups]
 docs = [
@@ -62,6 +65,7 @@ test = [
   "boost-histogram>=1.0",
   "fastjsonschema",
   "importlib-resources; python_version<'3.9'",
+  "h5py; platform_python_implementation == 'CPython'",
 ]
 dev = [{ include-group = "test"}]
 
@@ -89,7 +93,7 @@ warn_unreachable = true
 enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
 
 [[tool.mypy.overrides]]
-module = ["fastjsonschema"]
+module = ["fastjsonschema", "h5py"]
 ignore_missing_imports = true
 
diff --git a/src/uhi/io/__init__.py b/src/uhi/io/__init__.py
new file mode 100644
index 0000000..0f48e9d
--- /dev/null
+++ b/src/uhi/io/__init__.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+__all__ = ["ARRAY_KEYS", "LIST_KEYS"]
+
+ARRAY_KEYS = frozenset(
+    [
+        "values",
+        "variances",
+        "edges",
+        "counts",
+        "sum_of_weights",
+        "sum_of_weights_squared",
+    ]
+)
+
+LIST_KEYS = frozenset(
+    [
+        "categories",
+    ]
+)
diff --git a/src/uhi/io/hdf5.py b/src/uhi/io/hdf5.py
new file mode 100644
index 0000000..73bcf68
--- /dev/null
+++ b/src/uhi/io/hdf5.py
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from typing import Any
+
+import h5py
+import numpy as np
+
+from ..typing.serialization import AnyAxis, AnyHistogram, AnyStorage, Histogram
+from . import ARRAY_KEYS
+
+__all__ = ["read", "write"]
+
+
+def __dir__() -> list[str]:
+    return __all__
+
+
+def write(grp: h5py.Group, /, histogram: AnyHistogram) -> None:
+    """
+    Write a histogram to an HDF5 group.
+    """
+    # All referenced objects will be stored inside /{name}/ref_axes
+    hist_folder_storage = grp.create_group("ref_axes")
+
+    # Metadata
+
+    if "metadata" in histogram:
+        metadata_grp = grp.create_group("metadata")
+        for key, val1 in histogram["metadata"].items():
+            metadata_grp.attrs[key] = val1
+
+    # Axes
+    axes_dataset = grp.create_dataset(
+        "axes", len(histogram["axes"]), dtype=h5py.special_dtype(ref=h5py.Reference)
+    )
+    for i, axis in enumerate(histogram["axes"]):
+        # For each axis, create a group under ref_axes, fill it with the axis
+        # contents, and store a reference to it in the `axes` dataset above
+        ax_group = hist_folder_storage.create_group(f"axis_{i}")
+        ax_info = axis.copy()
+        ax_metadata = ax_info.pop("metadata", None)
+        ax_edges_raw = ax_info.pop("edges", None)
+        ax_edges = np.asarray(ax_edges_raw) if ax_edges_raw is not None else None
+        ax_cats: list[int] | list[str] | None = ax_info.pop("categories", None)
+        for key, val2 in ax_info.items():
+            ax_group.attrs[key] = val2
+        if ax_metadata is not None:
+            ax_metadata_grp = ax_group.create_group("metadata")
+            for k, v in ax_metadata.items():
+                ax_metadata_grp.attrs[k] = v
+        if ax_edges is not None:
+            ax_group.create_dataset("edges", shape=ax_edges.shape, data=ax_edges)
+        if ax_cats is not None:
+            ax_group.create_dataset("categories", shape=len(ax_cats), data=ax_cats)
+        axes_dataset[i] = ax_group.ref
+
+    # Storage
+    storage_grp = grp.create_group("storage")
+    storage_type = histogram["storage"]["type"]
+
+    storage_grp.attrs["type"] = storage_type
+
+    for key, val3 in histogram["storage"].items():
+        if key == "type":
+            continue
+        npvalue = np.asarray(val3)
+        storage_grp.create_dataset(key, shape=npvalue.shape, data=npvalue)
+
+
+def _convert_axes(group: h5py.Group | h5py.Dataset | h5py.Datatype) -> AnyAxis:
+    """
+    Convert an HDF5 
axis reference to a dictionary. + """ + assert isinstance(group, h5py.Group) + + axis = {k: _convert_item(k, v) for k, v in group.attrs.items()} + if "edges" in group: + edges = group["edges"] + assert isinstance(edges, h5py.Dataset) + axis["edges"] = np.asarray(edges) + if "categories" in group: + categories = group["categories"] + assert isinstance(categories, h5py.Dataset) + axis["categories"] = [_convert_item("", c) for c in categories] + + return axis # type: ignore[return-value] + + +def _convert_item(name: str, item: Any, /) -> Any: + """ + Convert an HDF5 item to a native Python type. + """ + if isinstance(item, bytes): + return item.decode("utf-8") + if name == "metadata": + return {k: _convert_item("", v) for k, v in item.items()} + if name in ARRAY_KEYS: + return item + if isinstance(item, np.generic): + return item.item() + return item + + +def read(grp: h5py.Group, /) -> Histogram: + """ + Read a histogram from an HDF5 group. + """ + axes_grp = grp["axes"] + axes_ref = grp["ref_axes"] + assert isinstance(axes_ref, h5py.Group) + assert isinstance(axes_grp, h5py.Dataset) + + axes = [_convert_axes(axes_ref[unref_axis_ref]) for unref_axis_ref in axes_ref] + + storage_grp = grp["storage"] + assert isinstance(storage_grp, h5py.Group) + storage = AnyStorage(type=storage_grp.attrs["type"]) + for key in storage_grp: + storage[key] = np.asarray(storage_grp[key]) # type: ignore[literal-required] + + histogram_dict = AnyHistogram(axes=axes, storage=storage) + if "metadata" in grp: + histogram_dict["metadata"] = _convert_item("metadata", grp["metadata"].attrs) + + return histogram_dict # type: ignore[return-value] diff --git a/src/uhi/io/json.py b/src/uhi/io/json.py new file mode 100644 index 0000000..b844f82 --- /dev/null +++ b/src/uhi/io/json.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from . import ARRAY_KEYS + +__all__ = ["default", "object_hook"] + + +def __dir__() -> list[str]: + return __all__ + + +def default(obj: Any, /) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() # Convert ndarray to list + msg = f"Object of type {type(obj)} is not JSON serializable" + raise TypeError(msg) + + +def object_hook(dct: dict[str, Any], /) -> dict[str, Any]: + """ + Decode a histogram from a dictionary. + """ + + for item in ARRAY_KEYS & dct.keys(): + if isinstance(dct[item], list): + dct[item] = np.asarray(dct[item]) + + return dct diff --git a/src/uhi/io/zip.py b/src/uhi/io/zip.py new file mode 100644 index 0000000..8bdf2e0 --- /dev/null +++ b/src/uhi/io/zip.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +import zipfile +from typing import Any + +import numpy as np + +from ..typing.serialization import AnyHistogram, Histogram +from . import ARRAY_KEYS + +__all__ = ["read", "write"] + + +def __dir__() -> list[str]: + return __all__ + + +def write( + zip_file: zipfile.ZipFile, + /, + name: str, + histogram: AnyHistogram, +) -> None: + """ + Write a histogram to a zip file. 
+    """
+    # Write out numpy arrays to files in the zipfile
+    for storage_key in ARRAY_KEYS & histogram["storage"].keys():
+        path = f"{name}_storage_{storage_key}.npy"
+        with zip_file.open(path, "w") as f:
+            np.save(f, histogram["storage"][storage_key])  # type: ignore[literal-required]
+        histogram["storage"][storage_key] = path  # type: ignore[literal-required]
+
+    for i, axis in enumerate(histogram["axes"]):
+        for key in ARRAY_KEYS & axis.keys():
+            path = f"{name}_axis_{i}_{key}.npy"
+            with zip_file.open(path, "w") as f:
+                np.save(f, axis[key])  # type: ignore[literal-required]
+            axis[key] = path  # type: ignore[literal-required]
+
+    hist_json = json.dumps(histogram)
+    zip_file.writestr(f"{name}.json", hist_json)
+
+
+def read(zip_file: zipfile.ZipFile, /, name: str) -> Histogram:
+    """
+    Read a histogram from a zip file.
+    """
+
+    def object_hook(dct: dict[str, Any], /) -> dict[str, Any]:
+        for item in ARRAY_KEYS & dct.keys():
+            if isinstance(dct[item], str):
+                dct[item] = np.load(zip_file.open(dct[item]))
+        return dct
+
+    with zip_file.open(f"{name}.json") as f:
+        return json.load(f, object_hook=object_hook)  # type: ignore[no-any-return]
diff --git a/src/uhi/resources/histogram.schema.json b/src/uhi/resources/histogram.schema.json
index 191e09b..5b4c100 100644
--- a/src/uhi/resources/histogram.schema.json
+++ b/src/uhi/resources/histogram.schema.json
@@ -52,6 +52,13 @@
       }
     }
   },
+    "ndarray": {
+      "type": "array",
+      "items": {
+        "oneOf": [{ "type": "number" }, { "$ref": "#/$defs/ndarray" }]
+      },
+      "description": "An ND (nested) array of numbers."
+    },
     "data_array": {
       "oneOf": [
         {
@@ -59,8 +66,7 @@
           "description": "A path (similar to URI) to the floating point bin data"
         },
         {
-          "type": "array",
-          "items": { "type": "number" }
+          "$ref": "#/$defs/ndarray"
         }
       ]
     },
diff --git a/src/uhi/typing/serialization.py b/src/uhi/typing/serialization.py
index fc74d9a..9fdd61e 100644
--- a/src/uhi/typing/serialization.py
+++ b/src/uhi/typing/serialization.py
@@ -1,9 +1,24 @@
+"""Serialization types for UHI.
+
+Two types of dictionaries are defined here:
+
+1. ``AnyAxis``, ``AnyStorage``, and ``AnyHistogram`` are used for inputs. They represent
+   the merger of all possible types.
+2. ``Axis``, ``Storage``, and ``Histogram`` are used for outputs. These have precise entries
+   defined for each Literal type. 
+""" + from __future__ import annotations -from collections.abc import Sequence from typing import Literal, TypedDict, Union +from numpy.typing import ArrayLike + __all__ = [ + "AnyAxis", + "AnyHistogram", + "AnyStorage", + "Axis", "BooleanAxis", "CategoryIntAxis", "CategoryStrAxis", @@ -12,6 +27,7 @@ "IntStorage", "MeanStorage", "RegularAxis", + "Storage", "VariableAxis", "WeightedMeanStorage", "WeightedStorage", @@ -40,7 +56,7 @@ class RegularAxis(_RequiredRegularAxis, total=False): class _RequiredVariableAxis(TypedDict): type: Literal["variable"] - edges: list[float] | str + edges: ArrayLike | str underflow: bool overflow: bool circular: bool @@ -80,43 +96,84 @@ class BooleanAxis(_RequiredBooleanAxis, total=False): class IntStorage(TypedDict): type: Literal["int"] - values: Sequence[int] | str + values: ArrayLike | str class DoubleStorage(TypedDict): type: Literal["double"] - values: Sequence[float] | str + values: ArrayLike | str class WeightedStorage(TypedDict): type: Literal["weighted"] - values: Sequence[float] | str - variances: Sequence[float] | str + values: ArrayLike | str + variances: ArrayLike | str class MeanStorage(TypedDict): type: Literal["mean"] - counts: Sequence[float] | str - values: Sequence[float] | str - variances: Sequence[float] | str + counts: ArrayLike | str + values: ArrayLike | str + variances: ArrayLike | str class WeightedMeanStorage(TypedDict): type: Literal["weighted_mean"] - sum_of_weights: Sequence[float] | str - sum_of_weights_squared: Sequence[float] | str - values: Sequence[float] | str - variances: Sequence[float] | str + sum_of_weights: ArrayLike | str + sum_of_weights_squared: ArrayLike | str + values: ArrayLike | str + variances: ArrayLike | str + + +Storage = Union[ + IntStorage, DoubleStorage, WeightedStorage, MeanStorage, WeightedMeanStorage +] + +Axis = Union[RegularAxis, VariableAxis, CategoryStrAxis, CategoryIntAxis, BooleanAxis] + + +class _RequiredAnyStorage(TypedDict): + type: Literal["int", "double", "weighted", "mean", "weighted_mean"] + + +class AnyStorage(_RequiredAnyStorage, total=False): + values: ArrayLike | str + variances: ArrayLike | str + sum_of_weights: ArrayLike | str + sum_of_weights_squared: ArrayLike | str + counts: ArrayLike | str + + +class _RequiredAnyAxis(TypedDict): + type: Literal["regular", "variable", "category_str", "category_int", "boolean"] + + +class AnyAxis(_RequiredAnyAxis, total=False): + metadata: dict[str, SupportedMetadata] + lower: float + upper: float + bins: int + edges: ArrayLike | str + categories: list[str] | list[int] + underflow: bool + overflow: bool + flow: bool + circular: bool class _RequiredHistogram(TypedDict): - axes: list[ - RegularAxis | VariableAxis | CategoryStrAxis | CategoryIntAxis | BooleanAxis - ] - storage: ( - IntStorage | DoubleStorage | WeightedStorage | MeanStorage | WeightedMeanStorage - ) + axes: list[Axis] + storage: Storage class Histogram(_RequiredHistogram, total=False): metadata: dict[str, SupportedMetadata] + + +class _RequiredAnyHistogram(TypedDict): + axes: list[AnyAxis] + storage: AnyStorage + + +class AnyHistogram(_RequiredAnyHistogram, total=False): + metadata: dict[str, SupportedMetadata] diff --git a/tests/resources/valid/2d.json b/tests/resources/valid/2d.json new file mode 100644 index 0000000..e083e5d --- /dev/null +++ b/tests/resources/valid/2d.json @@ -0,0 +1,27 @@ +{ + "main": { + "axes": [ + { + "type": "variable", + "edges": [1, 2, 3, 4], + "underflow": true, + "overflow": false, + "circular": false + }, + { + "type": "category_str", + "categories": ["a", 
"b", "c"], + "flow": false + } + ], + "storage": { + "type": "double", + "values": [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [10.0, 11.0, 12.0] + ] + } + } +} diff --git a/tests/resources/valid/reg.json b/tests/resources/valid/reg.json index 754fc05..48adca6 100644 --- a/tests/resources/valid/reg.json +++ b/tests/resources/valid/reg.json @@ -26,6 +26,6 @@ "circular": false } ], - "storage": { "type": "double", "values": "some/path/depends/on/format" } + "storage": { "type": "double", "values": [1, 2, 3, 4, 5, 6, 7] } } } diff --git a/tests/test_hdf5.py b/tests/test_hdf5.py new file mode 100644 index 0000000..8c9db4a --- /dev/null +++ b/tests/test_hdf5.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +import uhi.io.json + +h5py = pytest.importorskip("h5py", reason="h5py is not installed") +uhi_io_hdf5 = pytest.importorskip("uhi.io.hdf5") + +DIR = Path(__file__).parent.resolve() + +VALID_FILES = DIR.glob("resources/valid/*.json") + + +@pytest.mark.parametrize("filename", VALID_FILES, ids=lambda p: p.name) +def test_valid_json(filename: Path, tmp_path: Path) -> None: + data = filename.read_text(encoding="utf-8") + hists = json.loads(data, object_hook=uhi.io.json.object_hook) + + tmp_file = tmp_path / "test.h5" + with h5py.File(tmp_file, "w") as h5_file: + for name, hist in hists.items(): + uhi_io_hdf5.write(h5_file.create_group(name), hist) + + with h5py.File(tmp_file, "r") as h5_file: + rehists = {name: uhi_io_hdf5.read(h5_file[name]) for name in hists} + + assert hists.keys() == rehists.keys() + + for name in hists: + hist = hists[name] + rehist = rehists[name] + + # Check that the JSON representation is the same + redata = json.dumps(hist, default=uhi.io.json.default, sort_keys=True) + data = json.dumps(rehist, default=uhi.io.json.default, sort_keys=True) + assert redata.replace(" ", "").replace("\n", "") == data.replace( + " ", "" + ).replace("\n", "") + + +def test_reg_load(tmp_path: Path) -> None: + data = DIR / "resources/valid/reg.json" + hists = json.loads( + data.read_text(encoding="utf-8"), object_hook=uhi.io.json.object_hook + ) + + tmp_file = tmp_path / "test.h5" + with h5py.File(tmp_file, "w") as h5_file: + uhi_io_hdf5.write(h5_file.create_group("one"), hists["one"]) + + with h5py.File(tmp_file, "r") as h5_file: + one = uhi_io_hdf5.read(h5_file["one"]) + + assert one["metadata"] == {"one": True, "two": 2, "three": "three"} + + assert len(one["axes"]) == 1 + assert one["axes"][0]["type"] == "regular" + assert one["axes"][0]["lower"] == pytest.approx(0) + assert one["axes"][0]["upper"] == pytest.approx(5) + assert one["axes"][0]["bins"] == 3 + assert one["axes"][0]["underflow"] + assert one["axes"][0]["overflow"] + assert not one["axes"][0]["circular"] + + assert one["storage"]["type"] == "int" + assert one["storage"]["values"] == pytest.approx([1, 2, 3, 4, 5]) diff --git a/tests/test_json.py b/tests/test_json.py new file mode 100644 index 0000000..7020df7 --- /dev/null +++ b/tests/test_json.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +import uhi.io.json + +DIR = Path(__file__).parent.resolve() + +VALID_FILES = DIR.glob("resources/valid/*.json") + + +@pytest.mark.parametrize("filename", VALID_FILES, ids=lambda p: p.name) +def test_valid_json(filename: Path) -> None: + data = filename.read_text(encoding="utf-8") + hist = json.loads(data, object_hook=uhi.io.json.object_hook) + redata = json.dumps(hist, 
default=uhi.io.json.default) + + rehist = json.loads(redata, object_hook=uhi.io.json.object_hook) + assert redata.replace(" ", "").replace("\n", "") == data.replace(" ", "").replace( + "\n", "" + ) + + assert hist.keys() == rehist.keys() + + +def test_reg_load() -> None: + data = DIR / "resources/valid/reg.json" + hists = json.loads( + data.read_text(encoding="utf-8"), object_hook=uhi.io.json.object_hook + ) + one = hists["one"] + two = hists["two"] + + assert one["metadata"] == {"one": True, "two": 2, "three": "three"} + + assert len(one["axes"]) == 1 + assert one["axes"][0]["type"] == "regular" + assert one["axes"][0]["lower"] == pytest.approx(0) + assert one["axes"][0]["upper"] == pytest.approx(5) + assert one["axes"][0]["bins"] == 3 + assert one["axes"][0]["underflow"] + assert one["axes"][0]["overflow"] + assert not one["axes"][0]["circular"] + + assert one["storage"]["type"] == "int" + assert one["storage"]["values"] == pytest.approx([1, 2, 3, 4, 5]) + + assert len(two["axes"]) == 1 + assert two["axes"][0]["type"] == "regular" + assert two["axes"][0]["lower"] == pytest.approx(0) + assert two["axes"][0]["upper"] == pytest.approx(5) + assert two["axes"][0]["bins"] == 5 + assert two["axes"][0]["underflow"] + assert two["axes"][0]["overflow"] + assert not two["axes"][0]["circular"] + + assert two["storage"]["type"] == "double" + assert two["storage"]["values"] == pytest.approx([1, 2, 3, 4, 5, 6, 7]) diff --git a/tests/test_zip.py b/tests/test_zip.py new file mode 100644 index 0000000..a85d103 --- /dev/null +++ b/tests/test_zip.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import copy +import json +import zipfile +from pathlib import Path + +import pytest + +import uhi.io.json +import uhi.io.zip + +DIR = Path(__file__).parent.resolve() + +VALID_FILES = DIR.glob("resources/valid/*.json") + + +@pytest.mark.parametrize("filename", VALID_FILES, ids=lambda p: p.name) +def test_valid_json(filename: Path, tmp_path: Path) -> None: + data = filename.read_text(encoding="utf-8") + hists = json.loads(data, object_hook=uhi.io.json.object_hook) + + tmp_file = tmp_path / "test.zip" + with zipfile.ZipFile(tmp_file, "w") as zip_file: + for name, hist in hists.items(): + uhi.io.zip.write(zip_file, name, copy.deepcopy(hist)) + with zipfile.ZipFile(tmp_file, "r") as zip_file: + rehists = {name: uhi.io.zip.read(zip_file, name) for name in hists} + + assert hists.keys() == rehists.keys() + + for name in hists: + hist = hists[name] + rehist = rehists[name] + + # Check that the JSON representation is the same + redata = json.dumps(hist, default=uhi.io.json.default) + data = json.dumps(rehist, default=uhi.io.json.default) + assert redata.replace(" ", "").replace("\n", "") == data.replace( + " ", "" + ).replace("\n", "") + + +def test_reg_load(tmp_path: Path) -> None: + data = DIR / "resources/valid/reg.json" + hists = json.loads( + data.read_text(encoding="utf-8"), object_hook=uhi.io.json.object_hook + ) + + tmp_file = tmp_path / "test.zip" + with zipfile.ZipFile(tmp_file, "w") as zip_file: + for name, hist in hists.items(): + uhi.io.zip.write(zip_file, name, hist) + with zipfile.ZipFile(tmp_file, "r") as zip_file: + names = zip_file.namelist() + rehists = { + name[:-5]: uhi.io.zip.read(zip_file, name[:-5]) + for name in names + if name.endswith(".json") + } + with zip_file.open("one.json") as f: + native_one = json.load(f) + + assert set(names) == { + "one_storage_values.npy", + "one.json", + "two_storage_values.npy", + "two.json", + } + + assert native_one["storage"]["values"] == 
"one_storage_values.npy" + + one = rehists["one"] + two = rehists["two"] + + assert one.get("metadata", {}) == {"one": True, "two": 2, "three": "three"} + + assert len(one["axes"]) == 1 + assert one["axes"][0]["type"] == "regular" + assert one["axes"][0]["lower"] == pytest.approx(0) + assert one["axes"][0]["upper"] == pytest.approx(5) + assert one["axes"][0]["bins"] == 3 + assert one["axes"][0]["underflow"] + assert one["axes"][0]["overflow"] + assert not one["axes"][0]["circular"] + + assert one["storage"]["type"] == "int" + assert one["storage"]["values"] == pytest.approx([1, 2, 3, 4, 5]) + + assert len(two["axes"]) == 1 + assert two["axes"][0]["type"] == "regular" + assert two["axes"][0]["lower"] == pytest.approx(0) + assert two["axes"][0]["upper"] == pytest.approx(5) + assert two["axes"][0]["bins"] == 5 + assert two["axes"][0]["underflow"] + assert two["axes"][0]["overflow"] + assert not two["axes"][0]["circular"] + + assert two["storage"]["type"] == "double" + assert two["storage"]["values"] == pytest.approx([1, 2, 3, 4, 5, 6, 7])