Skip to content

Commit 8709d04

Browse files
feat: add read/write support
Co-authored-by: Aryaman Jeendgar <jeendgararyaman@gmail.com> Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
1 parent fa83337 commit 8709d04

File tree

12 files changed

+583
-23
lines changed

12 files changed

+583
-23
lines changed

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ schema = [
4747
"fastjsonschema",
4848
"importlib-resources; python_version<'3.9'",
4949
]
50+
hdf5 = [
51+
"h5py",
52+
]
5053

5154
[dependency-groups]
5255
docs = [
@@ -62,6 +65,7 @@ test = [
6265
"boost-histogram>=1.0",
6366
"fastjsonschema",
6467
"importlib-resources; python_version<'3.9'",
68+
"h5py; platform_python_implementation == 'cpython'",
6569
]
6670
dev = [{ include-group = "test"}]
6771

@@ -89,7 +93,7 @@ warn_unreachable = true
8993
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
9094

9195
[[tool.mypy.overrides]]
92-
module = ["fastjsonschema"]
96+
module = ["fastjsonschema", "h5py"]
9397
ignore_missing_imports = true
9498

9599

src/uhi/io/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
__all__ = ["ARRAY_KEYS", "LIST_KEYS"]
4+
5+
ARRAY_KEYS = frozenset(
6+
[
7+
"values",
8+
"variances",
9+
"edges",
10+
"counts",
11+
"sum_of_weights",
12+
"sum_of_weights_squared",
13+
]
14+
)
15+
16+
LIST_KEYS = frozenset(
17+
[
18+
"categories",
19+
]
20+
)

src/uhi/io/hdf5.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import h5py
6+
import numpy as np
7+
8+
from ..typing.serialization import AnyAxis, AnyHistogram, AnyStorage, Histogram
9+
from . import ARRAY_KEYS
10+
11+
__all__ = ["read", "write"]
12+
13+
14+
def __dir__() -> list[str]:
15+
return __all__
16+
17+
18+
def write(grp: h5py.Group, /, histogram: AnyHistogram) -> None:
19+
"""
20+
Write a histogram to an HDF5 group.
21+
"""
22+
# All referenced objects will be stored inside of /{name}/ref_axes
23+
hist_folder_storage = grp.create_group("ref_axes")
24+
25+
# Metadata
26+
27+
if "metadata" in histogram:
28+
metadata_grp = grp.create_group("metadata")
29+
for key, val1 in histogram["metadata"].items():
30+
metadata_grp.attrs[key] = val1
31+
32+
# Axes
33+
axes_dataset = grp.create_dataset(
34+
"axes", len(histogram["axes"]), dtype=h5py.special_dtype(ref=h5py.Reference)
35+
)
36+
for i, axis in enumerate(histogram["axes"]):
37+
# Iterating through the axes, calling `create_axes_object` for each of them,
38+
# creating references to new groups and appending it to the `items` dataset defined above
39+
ax_group = hist_folder_storage.create_group(f"axis_{i}")
40+
ax_info = axis.copy()
41+
ax_metadata = ax_info.pop("metadata", None)
42+
ax_edges_raw = ax_info.pop("edges", None)
43+
ax_edges = np.asarray(ax_edges_raw) if ax_edges_raw is not None else None
44+
ax_cats: list[int] | list[str] | None = ax_info.pop("categories", None)
45+
for key, val2 in ax_info.items():
46+
ax_group.attrs[key] = val2
47+
if ax_metadata is not None:
48+
ax_metadata_grp = ax_group.create_group("metadata")
49+
for k, v in ax_metadata.items():
50+
ax_metadata_grp.attrs[k] = v
51+
if ax_edges is not None:
52+
ax_group.create_dataset("edges", shape=ax_edges.shape, data=ax_edges)
53+
if ax_cats is not None:
54+
ax_group.create_dataset("categories", shape=len(ax_cats), data=ax_cats)
55+
axes_dataset[i] = ax_group.ref
56+
57+
# Storage
58+
storage_grp = grp.create_group("storage")
59+
storage_type = histogram["storage"]["type"]
60+
61+
storage_grp.attrs["type"] = storage_type
62+
63+
for key, val3 in histogram["storage"].items():
64+
if key == "type":
65+
continue
66+
npvalue = np.asarray(val3)
67+
storage_grp.create_dataset(key, shape=npvalue.shape, data=npvalue)
68+
69+
70+
def _convert_axes(group: h5py.Group | h5py.Dataset | h5py.Datatype) -> AnyAxis:
71+
"""
72+
Convert an HDF5 axis reference to a dictionary.
73+
"""
74+
assert isinstance(group, h5py.Group)
75+
76+
axis = {k: _convert_item(k, v) for k, v in group.attrs.items()}
77+
if "edges" in group:
78+
edges = group["edges"]
79+
assert isinstance(edges, h5py.Dataset)
80+
axis["edges"] = np.asarray(edges)
81+
if "categories" in group:
82+
categories = group["categories"]
83+
assert isinstance(categories, h5py.Dataset)
84+
axis["categories"] = [_convert_item("", c) for c in categories]
85+
86+
return axis # type: ignore[return-value]
87+
88+
89+
def _convert_item(name: str, item: Any, /) -> Any:
90+
"""
91+
Convert an HDF5 item to a native Python type.
92+
"""
93+
if isinstance(item, bytes):
94+
return item.decode("utf-8")
95+
if name == "metadata":
96+
return {k: _convert_item("", v) for k, v in item.items()}
97+
if name in ARRAY_KEYS:
98+
return item
99+
if isinstance(item, np.generic):
100+
return item.item()
101+
return item
102+
103+
104+
def read(grp: h5py.Group, /) -> Histogram:
105+
"""
106+
Read a histogram from an HDF5 group.
107+
"""
108+
axes_grp = grp["axes"]
109+
axes_ref = grp["ref_axes"]
110+
assert isinstance(axes_ref, h5py.Group)
111+
assert isinstance(axes_grp, h5py.Dataset)
112+
113+
axes = [_convert_axes(axes_ref[unref_axis_ref]) for unref_axis_ref in axes_ref]
114+
115+
storage_grp = grp["storage"]
116+
assert isinstance(storage_grp, h5py.Group)
117+
storage = AnyStorage(type=storage_grp.attrs["type"])
118+
for key in storage_grp:
119+
storage[key] = np.asarray(storage_grp[key]) # type: ignore[literal-required]
120+
121+
histogram_dict = AnyHistogram(axes=axes, storage=storage)
122+
if "metadata" in grp:
123+
histogram_dict["metadata"] = _convert_item("metadata", grp["metadata"].attrs)
124+
125+
return histogram_dict # type: ignore[return-value]

src/uhi/io/json.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import numpy as np
6+
7+
from . import ARRAY_KEYS
8+
9+
__all__ = ["default", "object_hook"]
10+
11+
12+
def __dir__() -> list[str]:
13+
return __all__
14+
15+
16+
def default(obj: Any, /) -> Any:
17+
if isinstance(obj, np.ndarray):
18+
return obj.tolist() # Convert ndarray to list
19+
msg = f"Object of type {type(obj)} is not JSON serializable"
20+
raise TypeError(msg)
21+
22+
23+
def object_hook(dct: dict[str, Any], /) -> dict[str, Any]:
24+
"""
25+
Decode a histogram from a dictionary.
26+
"""
27+
28+
for item in ARRAY_KEYS & dct.keys():
29+
if isinstance(dct[item], list):
30+
dct[item] = np.asarray(dct[item])
31+
32+
return dct

src/uhi/io/zip.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import zipfile
5+
from typing import Any
6+
7+
import numpy as np
8+
9+
from ..typing.serialization import AnyHistogram, Histogram
10+
from . import ARRAY_KEYS
11+
12+
__all__ = ["read", "write"]
13+
14+
15+
def __dir__() -> list[str]:
16+
return __all__
17+
18+
19+
def write(
20+
zip_file: zipfile.ZipFile,
21+
/,
22+
name: str,
23+
histogram: AnyHistogram,
24+
) -> None:
25+
"""
26+
Write a histogram to a zip file.
27+
"""
28+
# Write out numpy arrays to files in the zipfile
29+
for storage_key in ARRAY_KEYS & histogram["storage"].keys():
30+
path = f"{name}_storage_{storage_key}.npy"
31+
with zip_file.open(path, "w") as f:
32+
np.save(f, histogram["storage"][storage_key]) # type: ignore[literal-required]
33+
histogram["storage"][storage_key] = path # type: ignore[literal-required]
34+
35+
for axis in histogram["axes"]:
36+
for key in ARRAY_KEYS & axis.keys():
37+
path = f"{name}_axis_{key}.npy"
38+
with zip_file.open(path, "w") as f:
39+
np.save(f, axis[key]) # type: ignore[literal-required]
40+
axis[key] = path # type: ignore[literal-required]
41+
42+
hist_json = json.dumps(histogram)
43+
zip_file.writestr(f"{name}.json", hist_json)
44+
45+
46+
def read(zip_file: zipfile.ZipFile, /, name: str) -> Histogram:
47+
"""
48+
Read histograms from a zip file.
49+
"""
50+
51+
def object_hook(dct: dict[str, Any], /) -> dict[str, Any]:
52+
for item in ARRAY_KEYS & dct.keys():
53+
if isinstance(dct[item], str):
54+
dct[item] = np.load(zip_file.open(dct[item]))
55+
return dct
56+
57+
with zip_file.open(f"{name}.json") as f:
58+
return json.load(f, object_hook=object_hook) # type: ignore[no-any-return]

src/uhi/resources/histogram.schema.json

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,21 @@
5252
}
5353
}
5454
},
55+
"ndarray": {
56+
"type": "array",
57+
"items": {
58+
"oneOf": [{ "type": "number" }, { "$ref": "#/$defs/ndarray" }]
59+
},
60+
"description": "A ND (nested) array of numbers."
61+
},
5562
"data_array": {
5663
"oneOf": [
5764
{
5865
"type": "string",
5966
"description": "A path (similar to URI) to the floating point bin data"
6067
},
6168
{
62-
"type": "array",
63-
"items": { "type": "number" }
69+
"$ref": "#/$defs/ndarray"
6470
}
6571
]
6672
},

0 commit comments

Comments
 (0)