Skip to content

Commit 9da27ca

Browse files
feat: add read/write support
Co-authored-by: Aryaman Jeendgar <jeendgararyaman@gmail.com> Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
1 parent fa83337 commit 9da27ca

File tree

11 files changed

+458
-22
lines changed

11 files changed

+458
-22
lines changed

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ schema = [
4747
"fastjsonschema",
4848
"importlib-resources; python_version<'3.9'",
4949
]
50+
hdf5 = [
51+
"h5py",
52+
]
5053

5154
[dependency-groups]
5255
docs = [
@@ -62,6 +65,7 @@ test = [
6265
"boost-histogram>=1.0",
6366
"fastjsonschema",
6467
"importlib-resources; python_version<'3.9'",
68+
"h5py",
6569
]
6670
dev = [{ include-group = "test"}]
6771

@@ -89,7 +93,7 @@ warn_unreachable = true
8993
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
9094

9195
[[tool.mypy.overrides]]
92-
module = ["fastjsonschema"]
96+
module = ["fastjsonschema", "h5py"]
9397
ignore_missing_imports = true
9498

9599

src/uhi/io/__init__.py

Whitespace-only changes.

src/uhi/io/hdf5.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
3+
import h5py
4+
import numpy as np
5+
6+
from ..typing.serialization import AnyAxis, AnyHistogram, AnyStorage, Histogram
7+
8+
__all__ = ["read", "write"]
9+
10+
11+
def __dir__() -> list[str]:
    """Limit ``dir()`` on this module to the names listed in ``__all__``."""
    public_names = __all__
    return public_names
13+
14+
15+
def write(grp: h5py.Group, /, histogram: AnyHistogram) -> None:
    """
    Write a histogram to an HDF5 group.

    The following children are created inside ``grp``:

    * ``metadata`` (group, only if the histogram has metadata): one attribute
      per metadata entry.
    * ``ref_axes/axis_{i}`` (group, one per axis): every axis field except
      ``metadata``, ``edges`` and ``categories`` is stored as an attribute.
      ``edges`` is stored as the dataset ``edges``, ``categories`` as the
      dataset ``items``, and axis metadata as attributes of a nested
      ``metadata`` group.
    * ``axes`` (dataset): object references to the axis groups, in axis order.
    * ``storage`` (group): a ``type`` attribute plus one dataset per remaining
      storage entry.

    The ``histogram`` dictionary itself is not modified.
    """
    # All referenced objects will be stored inside of /{name}/ref_axes
    ref_axes_grp = grp.create_group("ref_axes")

    # Metadata
    if "metadata" in histogram:
        metadata_grp = grp.create_group("metadata")
        for key, val1 in histogram["metadata"].items():
            metadata_grp.attrs[key] = val1

    # Axes
    axes_dataset = grp.create_dataset(
        "axes", len(histogram["axes"]), dtype=h5py.special_dtype(ref=h5py.Reference)
    )
    for i, axis in enumerate(histogram["axes"]):
        # Each axis gets its own group under ref_axes; a reference to that
        # group is stored at position i of the "axes" dataset.
        ax_group = ref_axes_grp.create_group(f"axis_{i}")
        ax_info = axis.copy()  # shallow copy so pop() leaves the input intact
        ax_metadata = ax_info.pop("metadata", None)
        ax_edges = ax_info.pop("edges", None)
        ax_cats: list[int] | list[str] | None = ax_info.pop("categories", None)
        for key, val2 in ax_info.items():
            ax_group.attrs[key] = val2
        if ax_metadata is not None:
            ax_metadata_grp = ax_group.create_group("metadata")
            for k, v in ax_metadata.items():
                ax_metadata_grp.attrs[k] = v
        if ax_edges is not None:
            # Convert only when edges are present: np.asarray(None) is a 0-d
            # object array (never None), which would defeat this check.
            np_edges = np.asarray(ax_edges)
            ax_group.create_dataset("edges", shape=np_edges.shape, data=np_edges)
        if ax_cats is not None:
            if any(isinstance(cat, str) for cat in ax_cats):
                # NumPy would turn a list of str into a fixed-width unicode
                # ("U") array, which h5py cannot store; request a
                # variable-length string dtype explicitly.
                ax_group.create_dataset(
                    "items",
                    shape=len(ax_cats),
                    data=ax_cats,
                    dtype=h5py.special_dtype(vlen=str),
                )
            else:
                ax_group.create_dataset("items", shape=len(ax_cats), data=ax_cats)
        axes_dataset[i] = ax_group.ref

    # Storage
    storage_grp = grp.create_group("storage")
    storage_grp.attrs["type"] = histogram["storage"]["type"]

    for key, val3 in histogram["storage"].items():
        if key == "type":
            continue
        npvalue = np.asarray(val3)
        storage_grp.create_dataset(key, shape=npvalue.shape, data=npvalue)
64+
65+
66+
def read(grp: h5py.Group, /) -> Histogram:
    """
    Read a histogram from an HDF5 group.

    Expects the children ``axes`` (dataset of object references),
    ``ref_axes`` (group) and ``storage`` (group), plus an optional
    ``metadata`` group. Axes are returned in the order of the references in
    ``axes``. For each axis, the group attributes become the scalar fields,
    and the optional ``edges`` dataset, ``items`` dataset (returned as
    ``categories``) and ``metadata`` group are restored when present.
    """
    axes_grp = grp["axes"]
    axes_ref = grp["ref_axes"]
    assert isinstance(axes_ref, h5py.Group)
    assert isinstance(axes_grp, h5py.Dataset)

    axes = []
    # Follow the stored references rather than iterating group names: names
    # iterate in lexicographic order ("axis_10" before "axis_2"), which would
    # scramble histograms with more than ten axes.
    for axis_ref in axes_grp:
        ax_group = grp[axis_ref]
        axis = AnyAxis(**ax_group.attrs)  # type: ignore[typeddict-item]
        if "edges" in ax_group:
            axis["edges"] = np.asarray(ax_group["edges"])
        if "items" in ax_group:
            # Variable-length strings come back as bytes; decode them so string
            # categories round-trip as str. Integer categories pass through.
            raw_cats = np.asarray(ax_group["items"]).tolist()
            axis["categories"] = [
                cat.decode("utf-8") if isinstance(cat, bytes) else cat
                for cat in raw_cats
            ]
        if "metadata" in ax_group:
            axis["metadata"] = dict(ax_group["metadata"].attrs.items())
        axes.append(axis)

    storage_grp = grp["storage"]
    assert isinstance(storage_grp, h5py.Group)
    storage = AnyStorage(type=storage_grp.attrs["type"])
    for key in storage_grp:
        storage[key] = np.asarray(storage_grp[key])  # type: ignore[literal-required]

    histogram_dict = AnyHistogram(axes=axes, storage=storage)
    if "metadata" in grp:
        histogram_dict["metadata"] = dict(grp["metadata"].attrs.items())

    return histogram_dict  # type: ignore[return-value]

src/uhi/io/json.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import numpy as np
6+
7+
__all__ = ["default", "object_hook"]
8+
9+
10+
CONVERT_KEYS = [
11+
"values",
12+
"variances",
13+
"edges",
14+
"counts",
15+
"sum_of_weights",
16+
"sum_of_weights_squared",
17+
]
18+
19+
20+
def __dir__() -> list[str]:
    """Limit ``dir()`` on this module to the names listed in ``__all__``."""
    public_names = __all__
    return public_names
22+
23+
24+
def default(obj: Any, /) -> Any:
    """
    Fallback encoder for ``json.dump(..., default=default)``.

    NumPy arrays are converted to (nested) lists and NumPy scalars to the
    matching built-in Python scalar. Any other object raises ``TypeError``,
    as the ``json`` module expects from a ``default`` hook.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()  # Convert ndarray to list
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 / np.bool_ are not subclasses of the
        # built-in types the json encoder knows about.
        return obj.item()
    msg = f"Object of type {type(obj)} is not JSON serializable"
    raise TypeError(msg)
29+
30+
31+
def object_hook(dct: dict[str, Any], /) -> dict[str, Any]:
    """
    Decode a histogram from a dictionary.

    Intended as the ``object_hook`` of ``json.load``: for every key listed in
    ``CONVERT_KEYS`` whose value is a list, the value is replaced in place by
    a NumPy array. The same dictionary is returned.
    """
    for key in CONVERT_KEYS:
        value = dct.get(key)
        if isinstance(value, list):
            dct[key] = np.asarray(value)
    return dct

src/uhi/io/zip.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import zipfile
5+
from typing import Any
6+
7+
import numpy as np
8+
9+
from ..typing.serialization import AnyHistogram, Histogram
10+
from .json import CONVERT_KEYS
11+
12+
__all__ = ["read", "write"]
13+
14+
15+
def write(
    zip_file: zipfile.ZipFile,
    /,
    name: str,
    histogram: AnyHistogram,
) -> None:
    """
    Write a histogram to a zip file.

    Every storage or axis entry whose key is in ``CONVERT_KEYS`` is saved as
    its own ``.npy`` member (``{name}_storage_{key}.npy`` or
    ``{name}_axis_{i}_{key}.npy``). The remaining structure, with those
    entries replaced by the member paths, is written as ``{name}.json``.

    The ``histogram`` passed in is left unchanged.
    """
    # Work on shallow copies so the caller's histogram keeps its arrays
    # instead of having them replaced by archive paths.
    storage = dict(histogram["storage"])
    for storage_key in list(storage):
        if storage_key in CONVERT_KEYS:
            path = f"{name}_storage_{storage_key}.npy"
            # Close the member before opening the next one: a ZipFile allows
            # only one open writing handle at a time.
            with zip_file.open(path, "w") as member:
                np.save(member, storage[storage_key])
            storage[storage_key] = path

    axes = []
    for i, axis in enumerate(histogram["axes"]):
        axis_out = dict(axis)
        for key in list(axis_out):
            if key in CONVERT_KEYS:
                # The axis index keeps paths unique when several axes carry
                # the same key (e.g. two variable axes both have "edges").
                path = f"{name}_axis_{i}_{key}.npy"
                with zip_file.open(path, "w") as member:
                    np.save(member, axis_out[key])
                axis_out[key] = path
        axes.append(axis_out)

    # Re-assigning existing keys keeps the original key order in the JSON.
    serializable = {**histogram, "storage": storage, "axes": axes}
    zip_file.writestr(f"{name}.json", json.dumps(serializable))
40+
41+
42+
def read(zip_file: zipfile.ZipFile, /, name: str) -> Histogram:
    """
    Read a histogram from a zip file.

    Loads ``{name}.json`` from the archive. While decoding, every entry whose
    key is in ``CONVERT_KEYS`` and whose value is a string is treated as the
    path of a ``.npy`` member and replaced by the array loaded from it.
    """

    def object_hook(dct: dict[str, Any], /) -> dict[str, Any]:
        for item in CONVERT_KEYS:
            member_path = dct.get(item)
            if isinstance(member_path, str):
                # np.load reads a .npy stream fully into memory, so the member
                # handle can be closed as soon as it returns.
                with zip_file.open(member_path) as member:
                    dct[item] = np.load(member)
        return dct

    with zip_file.open(f"{name}.json") as f:
        return json.load(f, object_hook=object_hook)  # type: ignore[no-any-return]

src/uhi/resources/histogram.schema.json

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,21 @@
5252
}
5353
}
5454
},
55+
"ndarray": {
56+
"type": "array",
57+
"items": {
58+
"oneOf": [{ "type": "number" }, { "$ref": "#/$defs/ndarray" }]
59+
},
60+
"description": "A ND (nested) array of numbers."
61+
},
5562
"data_array": {
5663
"oneOf": [
5764
{
5865
"type": "string",
5966
"description": "A path (similar to URI) to the floating point bin data"
6067
},
6168
{
62-
"type": "array",
63-
"items": { "type": "number" }
69+
"$ref": "#/$defs/ndarray"
6470
}
6571
]
6672
},

src/uhi/typing/serialization.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
1+
"""Serialization types for UHI.
2+
3+
Two types of dictionaries are defined here:
4+
5+
1. ``AnyAxis``, ``AnyStorage``, and ``AnyHistogram`` are used for inputs. They represent
6+
the merger of all possible types.
7+
2. ``Axis``, ``Storage``, and ``Histogram`` are used for outputs. These have precise entries
8+
defined for each Literal type.
9+
"""
10+
111
from __future__ import annotations
212

3-
from collections.abc import Sequence
413
from typing import Literal, TypedDict, Union
514

15+
from numpy.typing import ArrayLike
16+
617
__all__ = [
18+
"AnyAxis",
19+
"AnyHistogram",
20+
"AnyStorage",
21+
"Axis",
722
"BooleanAxis",
823
"CategoryIntAxis",
924
"CategoryStrAxis",
@@ -12,6 +27,7 @@
1227
"IntStorage",
1328
"MeanStorage",
1429
"RegularAxis",
30+
"Storage",
1531
"VariableAxis",
1632
"WeightedMeanStorage",
1733
"WeightedStorage",
@@ -40,7 +56,7 @@ class RegularAxis(_RequiredRegularAxis, total=False):
4056

4157
class _RequiredVariableAxis(TypedDict):
4258
type: Literal["variable"]
43-
edges: list[float] | str
59+
edges: ArrayLike | str
4460
underflow: bool
4561
overflow: bool
4662
circular: bool
@@ -80,43 +96,84 @@ class BooleanAxis(_RequiredBooleanAxis, total=False):
8096

8197
class IntStorage(TypedDict):
8298
type: Literal["int"]
83-
values: Sequence[int] | str
99+
values: ArrayLike | str
84100

85101

86102
class DoubleStorage(TypedDict):
87103
type: Literal["double"]
88-
values: Sequence[float] | str
104+
values: ArrayLike | str
89105

90106

91107
class WeightedStorage(TypedDict):
92108
type: Literal["weighted"]
93-
values: Sequence[float] | str
94-
variances: Sequence[float] | str
109+
values: ArrayLike | str
110+
variances: ArrayLike | str
95111

96112

97113
class MeanStorage(TypedDict):
98114
type: Literal["mean"]
99-
counts: Sequence[float] | str
100-
values: Sequence[float] | str
101-
variances: Sequence[float] | str
115+
counts: ArrayLike | str
116+
values: ArrayLike | str
117+
variances: ArrayLike | str
102118

103119

104120
class WeightedMeanStorage(TypedDict):
105121
type: Literal["weighted_mean"]
106-
sum_of_weights: Sequence[float] | str
107-
sum_of_weights_squared: Sequence[float] | str
108-
values: Sequence[float] | str
109-
variances: Sequence[float] | str
122+
sum_of_weights: ArrayLike | str
123+
sum_of_weights_squared: ArrayLike | str
124+
values: ArrayLike | str
125+
variances: ArrayLike | str
126+
127+
128+
Storage = Union[
129+
IntStorage, DoubleStorage, WeightedStorage, MeanStorage, WeightedMeanStorage
130+
]
131+
132+
Axis = Union[RegularAxis, VariableAxis, CategoryStrAxis, CategoryIntAxis, BooleanAxis]
133+
134+
135+
class _RequiredAnyStorage(TypedDict):
136+
type: Literal["int", "double", "weighted", "mean", "weighted_mean"]
137+
138+
139+
class AnyStorage(_RequiredAnyStorage, total=False):
140+
values: ArrayLike | str
141+
variances: ArrayLike | str
142+
sum_of_weights: ArrayLike | str
143+
sum_of_weights_squared: ArrayLike | str
144+
counts: ArrayLike | str
145+
146+
147+
class _RequiredAnyAxis(TypedDict):
148+
type: Literal["regular", "variable", "category_str", "category_int", "boolean"]
149+
150+
151+
class AnyAxis(_RequiredAnyAxis, total=False):
152+
metadata: dict[str, SupportedMetadata]
153+
lower: float
154+
upper: float
155+
bins: int
156+
edges: ArrayLike | str
157+
categories: list[str] | list[int]
158+
underflow: bool
159+
overflow: bool
160+
flow: bool
161+
circular: bool
110162

111163

112164
class _RequiredHistogram(TypedDict):
113-
axes: list[
114-
RegularAxis | VariableAxis | CategoryStrAxis | CategoryIntAxis | BooleanAxis
115-
]
116-
storage: (
117-
IntStorage | DoubleStorage | WeightedStorage | MeanStorage | WeightedMeanStorage
118-
)
165+
axes: list[Axis]
166+
storage: Storage
119167

120168

121169
class Histogram(_RequiredHistogram, total=False):
122170
metadata: dict[str, SupportedMetadata]
171+
172+
173+
class _RequiredAnyHistogram(TypedDict):
174+
axes: list[AnyAxis]
175+
storage: AnyStorage
176+
177+
178+
class AnyHistogram(_RequiredAnyHistogram, total=False):
179+
metadata: dict[str, SupportedMetadata]

0 commit comments

Comments
 (0)