Skip to content

Commit 4429239

Browse files
committed
Adding CI tests to typetracer builder, get_necessary_branches.
Test on root file
1 parent 60596c0 commit 4429239

File tree

1 file changed

+134
-6
lines changed

1 file changed

+134
-6
lines changed

tests/test_typetracer.py

Lines changed: 134 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,146 @@
11
import awkward as ak
2+
from awkward.forms import ListOffsetForm, NumpyForm
23
import pytest
3-
44
from servicex_analysis_utils import read_buffers
5+
from skhep_testdata import data_path
6+
import uproot
57

68

7-
def test_simple_record_typetracer():
9+
@pytest.fixture
10+
def setup_form():
11+
# Create a simple awkward array with a RecordForm
812
arr = ak.Array([{"x": [1, 2, 3], "y": [4, 5]}])
9-
1013
form = arr.layout.form
11-
tracer, report = read_buffers.build_typetracer(form)
14+
15+
# Build a form from a .root file
16+
file = data_path("uproot-Zmumu.root") + ":events"
17+
array = uproot.open(file).arrays(library="ak")
18+
return form, array.layout.form
19+
20+
21+
@pytest.fixture
22+
def setup_type_tracer(setup_form, from_uproot):
23+
array_form, uproot_form = setup_form
24+
if from_uproot:
25+
form = uproot_form
26+
else:
27+
form = array_form
28+
# Build a typetracer with the form
29+
tracer, report = read_buffers.build_typetracer_with_report(form)
30+
return tracer, report
31+
32+
33+
def test_instance_of_record_form(setup_form):
34+
simple_form, root_form = setup_form
35+
# Check instances of RecordForm
36+
assert isinstance(
37+
simple_form, ak.forms.RecordForm
38+
), f"Form is {type(simple_form)}, but should be RecordForm"
39+
40+
assert isinstance(
41+
root_form, ak.forms.RecordForm
42+
), f"Form is {type(root_form)}, but should be RecordForm"
43+
44+
45+
@pytest.mark.parametrize("from_uproot", [True, False])
46+
def test_built_typetracer_instance(setup_type_tracer):
47+
tracer, report = setup_type_tracer
48+
# Check if the tracer is an instance of high-level Array
49+
assert isinstance(
50+
tracer, ak.highlevel.Array
51+
), f"Tracer should be a highlevel Array but is {type(tracer)}"
52+
53+
# Check if the report is an instance of Report
54+
assert isinstance(report, ak.typetracer.TypeTracerReport), "Report is not a Report"
55+
# Check if the report has data_touched
56+
assert hasattr(
57+
report, "data_touched"
58+
), "Report does not have data_touched attribute"
59+
60+
61+
@pytest.mark.parametrize("from_uproot", [False])
62+
def test_simple_record_typetracer(setup_type_tracer):
63+
tracer, report = setup_type_tracer
64+
65+
assert tracer.fields == ["x", "y"], "Fields of the tracer do not match expected"
1266

1367
# “Touch” one of the two fields
1468
_ = tracer["x"] + 0
1569

16-
# Collect the branches and assert exactly one branch is needed (x)
17-
touched_branches = read_buffers.necessary_branches(report)
70+
# Collect the touched branches
71+
touched_branches = read_buffers.get_necessary_branches(report)
1872
assert set(touched_branches) == {"x"}
73+
74+
75+
@pytest.mark.parametrize("from_uproot", [True])
76+
def test_root_record_typetracer(setup_type_tracer):
77+
tracer, report = setup_type_tracer
78+
expected_fields = [
79+
"Type",
80+
"Run",
81+
"Event",
82+
"E1",
83+
"px1",
84+
"py1",
85+
"pz1",
86+
"pt1",
87+
"eta1",
88+
"phi1",
89+
"Q1",
90+
"E2",
91+
"px2",
92+
"py2",
93+
"pz2",
94+
"pt2",
95+
"eta2",
96+
"phi2",
97+
"Q2",
98+
"M",
99+
]
100+
101+
assert (
102+
tracer.fields == expected_fields
103+
), "Fields of the tracer do not match expected"
104+
105+
# Compute deltaR on typetracer
106+
delta_r = (
107+
(tracer["eta1"] - tracer["eta2"]) ** 2 + (tracer["phi1"] - tracer["phi2"]) ** 2
108+
) ** 0.5
109+
110+
# Check delta_r is still an array
111+
assert isinstance(
112+
delta_r, ak.highlevel.Array
113+
), f"delta_r should be a highlevel Array but is {type(delta_r)}"
114+
115+
# Check no data is loaded in computation
116+
# only TypeTracer placeholders should be present
117+
first_element = delta_r[0]
118+
assert "TypeTracer" in repr(
119+
first_element
120+
), f"Expected a TypeTracer placeholder, got {repr(first_element)} instead."
121+
122+
# Collect the touched branches
123+
touched_branches = read_buffers.get_necessary_branches(report)
124+
assert set(touched_branches) == {
125+
"eta1",
126+
"phi1",
127+
"eta2",
128+
"phi2",
129+
}, "Touched branches do not match expected branches"
130+
131+
132+
def test_error_on_wrong_form():
133+
# Test that ValueError is raised when an empty form is passed
134+
with pytest.raises(
135+
ValueError,
136+
match="Unsupported form type: <class 'NoneType'>. This function only supports RecordForm.",
137+
):
138+
read_buffers.add_keys(None)
139+
140+
# Test that ValueError is raised when another form type is passed
141+
dummy = ListOffsetForm("i64", NumpyForm("int32"))
142+
with pytest.raises(
143+
ValueError,
144+
match="Unsupported form type: <class 'awkward.forms.listoffsetform.ListOffsetForm'>. This function only supports RecordForm.",
145+
):
146+
read_buffers.add_keys(dummy)

0 commit comments

Comments
 (0)