Skip to content

Commit 68cebf5

Browse files
authored
Merge pull request #590 from ggmarshall/evt_changes
make evt more generic to handle new tiers
2 parents 981877e + 2d143cd commit 68cebf5

File tree

1 file changed

+56
-43
lines changed

1 file changed

+56
-43
lines changed

src/pygama/evt/utils.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,22 @@
1616
H5DataLoc = namedtuple(
1717
"H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,)
1818
)
19-
20-
DataInfo = namedtuple(
21-
"DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,)
22-
)
19+
DataInfo = namedtuple("DataInfo", ("raw", "tcm", "evt"), defaults=3 * (None,))
2320

2421
TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length"))
2522

2623

2724
def make_files_config(data: dict):
28-
if not isinstance(data, DataInfo):
25+
if not isinstance(data, tuple):
26+
if "raw" not in data:
27+
data["raw"] = (None,)
28+
if "tcm" not in data:
29+
data["tcm"] = (None,)
30+
if "evt" not in data:
31+
data["evt"] = (None,)
32+
DataInfo = namedtuple(
33+
"DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,)
34+
)
2935
return DataInfo(
3036
*[
3137
H5DataLoc(*data[tier]) if tier in data else H5DataLoc()
@@ -72,7 +78,7 @@ def find_parameters(
7278
idx_ch,
7379
field_list,
7480
) -> dict:
75-
"""Finds and returns parameters from `hit` and `dsp` tiers.
81+
"""Finds and returns parameters from non `tcm`, `evt` tiers.
7682
7783
Parameters
7884
----------
@@ -83,43 +89,38 @@ def find_parameters(
8389
idx_ch
8490
index array of entries to be read from datainfo.
8591
field_list
86-
list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers.
92+
list of tuples ``(tier, field)`` to be found in non `tcm`, `evt` tiers.
8793
"""
8894
f = make_files_config(datainfo)
8995

90-
# find fields in either dsp, hit
91-
dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group]
92-
hit_flds = [e[1] for e in field_list if e[0] == f.hit.group]
96+
final_dict = {}
9397

94-
hit_dict, dsp_dict = {}, {}
98+
for name, tier in f._asdict().items():
99+
if name not in ["tcm", "evt"] and tier.file is not None: # skip other tables
100+
keys = [
101+
k.split("/")[-1]
102+
for k in lh5.ls(tier.file, f"{ch.replace('/', '')}/{tier.group}/")
103+
]
104+
flds = [e[1] for e in field_list if e[0] == name and e[1] in keys]
95105

96-
if len(hit_flds) > 0:
97-
hit_ak = lh5.read_as(
98-
f"{ch.replace('/', '')}/{f.hit.group}/",
99-
f.hit.file,
100-
field_mask=hit_flds,
101-
idx=idx_ch,
102-
library="ak",
103-
)
106+
if len(flds) > 0:
107+
tier_ak = lh5.read_as(
108+
f"{ch.replace('/', '')}/{tier.group}/",
109+
tier.file,
110+
field_mask=flds,
111+
idx=idx_ch,
112+
library="ak",
113+
)
104114

105-
hit_dict = dict(
106-
zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak))
107-
)
115+
tier_dict = dict(
116+
zip(
117+
[f"{name}_" + e for e in ak.fields(tier_ak)],
118+
ak.unzip(tier_ak),
119+
)
120+
)
121+
final_dict = final_dict | tier_dict
108122

109-
if len(dsp_flds) > 0:
110-
dsp_ak = lh5.read_as(
111-
f"{ch.replace('/', '')}/{f.dsp.group}/",
112-
f.dsp.file,
113-
field_mask=dsp_flds,
114-
idx=idx_ch,
115-
library="ak",
116-
)
117-
118-
dsp_dict = dict(
119-
zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak))
120-
)
121-
122-
return hit_dict | dsp_dict
123+
return final_dict
123124

124125

125126
def get_data_at_channel(
@@ -178,10 +179,16 @@ def get_data_at_channel(
178179

179180
# evaluate expression
180181
# move tier+dots in expression to underscores (e.g. evt.foo -> evt_foo)
182+
183+
new_expr = expr
184+
for name in f._asdict():
185+
if name == "evt":
186+
new_expr = new_expr.replace(f"{name}.", "")
187+
elif name not in ["tcm", "raw"]:
188+
new_expr = new_expr.replace(f"{name}.", f"{name}_")
189+
181190
res = eval(
182-
expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_")
183-
.replace(f"{f.hit.group}.", f"{f.hit.group}_")
184-
.replace(f"{f.evt.group}.", ""),
191+
new_expr,
185192
var,
186193
)
187194

@@ -231,17 +238,23 @@ def get_mask_from_query(
231238

232239
# get sub evt based query condition if needed
233240
if isinstance(query, str):
234-
query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query)
241+
query_lst = re.findall(
242+
rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", query
243+
)
235244
query_var = find_parameters(
236245
datainfo=datainfo,
237246
ch=ch,
238247
idx_ch=idx_ch,
239248
field_list=query_lst,
240249
)
250+
251+
new_query = query
252+
for name in f._asdict():
253+
if name not in ["tcm", "evt"]:
254+
new_query = new_query.replace(f"{name}.", f"{name}_")
255+
241256
limarr = eval(
242-
query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace(
243-
f"{f.hit.group}.", f"{f.hit.group}_"
244-
),
257+
new_query,
245258
query_var,
246259
)
247260

0 commit comments

Comments
 (0)