
Commit 6ea4adf

tcm handle multiple cols, change to vov (#599)
* tcm to vov and use iterators so it can handle large files
* update evt for new tcm, and cleanup
* update legendtestdata commit
* handle multiple instances of a channel in an evt/tcm entry
* use pytest tmpdir
* fix naming to table_key and row_in_table
1 parent a772619 commit 6ea4adf

21 files changed (+1030 / -600 lines)
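
The change running through the whole diff is the new TCM layout: instead of flat id/idx arrays, the coincidence map now carries one list per event in the fields table_key and row_in_table, so a channel can contribute more than one hit to the same event. A minimal sketch of the indexing idiom the patch uses; the toy arrays and variable names here are illustrative, not pygama API:

import awkward as ak
import numpy as np

# Toy TCM in the new vector-of-vectors layout: one list per event.
# table_key holds the table (channel) id of each hit, row_in_table the row of
# that hit inside the channel's hit table; a channel may occur more than once
# in the same event.
tcm_table_key = ak.Array([[1, 2], [2, 2], [1], [1, 2, 2]])
tcm_row_in_table = ak.Array([[0, 0], [1, 2], [1], [2, 3, 4]])

table_id = 2  # channel to aggregate

# rows to load from this channel's hit table (same idiom as in the diff)
chan_tcm_indexs = ak.flatten(tcm_table_key) == table_id
idx_ch = ak.flatten(tcm_row_in_table)[chan_tcm_indexs].to_numpy()

# event index of each of those hits: repeat each event number by how many
# hits this channel contributes to it
evt_ids_ch = np.repeat(
    np.arange(0, len(tcm_table_key)),
    ak.sum(tcm_table_key == table_id, axis=1).to_numpy(),
)

print(idx_ch)      # [0 1 2 3 4]
print(evt_ids_ch)  # [0 1 1 3 3]

For each channel, idx_ch says which rows of that channel's hit table to read, and evt_ids_ch says which event each of those rows belongs to.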

src/pygama/evt/aggregators.py

Lines changed: 97 additions & 102 deletions
@@ -6,8 +6,8 @@
 
 import awkward as ak
 import numpy as np
+import pandas as pd
 from lgdo import lh5, types
-from lgdo.lh5 import LH5Store
 
 from . import utils
 
@@ -58,15 +58,14 @@ def evaluate_to_first_or_last(
     """
     f = utils.make_files_config(datainfo)
 
-    out = None
-    outt = None
-    store = LH5Store(keep_open=True)
+    df = None
 
     for ch in channels:
         table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch)
 
         # get index list for this channel to be loaded
-        idx_ch = tcm.idx[tcm.id == table_id]
+        chan_tcm_indexs = ak.flatten(tcm.table_key) == table_id
+        idx_ch = ak.flatten(tcm.row_in_table)[chan_tcm_indexs].to_numpy()
 
         # evaluate at channel
         if ch not in channels_skip:
@@ -79,58 +78,52 @@ def evaluate_to_first_or_last(
                 pars_dict=pars_dict,
             )
 
-            if out is None:
+            if df is None:
                 # define dimension of output array
                 out = utils.make_numpy_full(n_rows, default_value, res.dtype)
-                outt = np.zeros(len(out))
-        else:
-            res = np.full(len(idx_ch), default_value)
-
-        # get mask from query
-        limarr = utils.get_mask_from_query(
-            datainfo=datainfo,
-            query=query,
-            length=len(res),
-            ch=ch,
-            idx_ch=idx_ch,
-        )
+                df = pd.DataFrame({"sort_field": np.zeros(len(out)), "res": out})
 
-        # find if sorter is in hit or dsp
-        t0 = store.read(
-            f"{ch}/{sorter[0]}/{sorter[1]}",
-            f.hit.file if f"{f.hit.group}" == sorter[0] else f.dsp.file,
-            idx=idx_ch,
-        )[0].view_as("np")
+            # get mask from query
+            limarr = utils.get_mask_from_query(
+                datainfo=datainfo,
+                query=query,
+                length=len(res),
+                ch=ch,
+                idx_ch=idx_ch,
+            )
 
-        if t0.ndim > 1:
-            raise ValueError(f"sorter '{sorter[0]}/{sorter[1]}' must be a 1D array")
+            # find if sorter is in hit or dsp
+            sort_field = lh5.read_as(
+                f"{ch}/{sorter[0]}/{sorter[1]}",
+                f.hit.file if f"{f.hit.group}" == sorter[0] else f.dsp.file,
+                idx=idx_ch,
+                library="np",
+            )
 
-        evt_ids_ch = np.searchsorted(
-            tcm.cumulative_length,
-            np.where(tcm.id == table_id)[0],
-            "right",
-        )
+            if sort_field.ndim > 1:
+                raise ValueError(f"sorter '{sorter[0]}/{sorter[1]}' must be a 1D array")
 
-        if is_first:
-            if ch == channels[0]:
-                outt[:] = np.inf
+            ch_df = pd.DataFrame({"sort_field": sort_field, "res": res})
 
-            out[evt_ids_ch] = np.where(
-                (t0 < outt[evt_ids_ch]) & (limarr), res, out[evt_ids_ch]
-            )
-            outt[evt_ids_ch] = np.where(
-                (t0 < outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
+            evt_ids_ch = np.repeat(
+                np.arange(0, len(tcm.table_key)),
+                ak.sum(tcm.table_key == table_id, axis=1),
             )
 
-        else:
-            out[evt_ids_ch] = np.where(
-                (t0 > outt[evt_ids_ch]) & (limarr), res, out[evt_ids_ch]
-            )
-            outt[evt_ids_ch] = np.where(
-                (t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
-            )
+            if is_first:
+                if ch == channels[0]:
+                    df["sort_field"] = np.inf
+                ids = (
+                    ch_df.sort_field.to_numpy() < df.sort_field[evt_ids_ch].to_numpy()
+                ) & (limarr)
+            else:
+                ids = (
+                    ch_df.sort_field.to_numpy() > df.sort_field[evt_ids_ch].to_numpy()
+                ) & (limarr)
 
-    return types.Array(nda=out)
+            df.loc[evt_ids_ch[ids], list(df.columns)] = ch_df.loc[ids, list(df.columns)]
+
+    return types.Array(nda=df.res.to_numpy())
 
 
 def evaluate_to_scalar(
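
evaluate_to_first_or_last now keeps a running pandas DataFrame with the best (sort_field, res) pair per event and lets each channel's hits challenge it, replacing the outt buffer and the searchsorted lookup. The following is a deliberately simplified, loop-based sketch of what that per-event "earliest hit wins" update computes; it is not the vectorized DataFrame code above, and the data is made up:

import numpy as np

n_events = 4
best_sort = np.full(n_events, np.inf)  # best sort_field (e.g. timestamp) per event
best_res = np.zeros(n_events)          # value carried along with it

# one channel's hits: the event each hit belongs to, its sort field and value
evt_ids_ch = np.array([0, 1, 1, 3])
sort_field = np.array([5.0, 2.0, 7.0, 1.0])
res = np.array([10.0, 20.0, 30.0, 40.0])

# "is_first" behaviour: a hit replaces the current best of its event if it is
# earlier; doing it hit by hit also copes with two hits of the same channel
# falling into the same event
for evt, t, r in zip(evt_ids_ch, sort_field, res):
    if t < best_sort[evt]:
        best_sort[evt], best_res[evt] = t, r

print(best_res)  # [10. 20.  0. 40.]
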
@@ -180,7 +173,8 @@ def evaluate_to_scalar(
         table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch)
 
         # get index list for this channel to be loaded
-        idx_ch = tcm.idx[tcm.id == table_id]
+        chan_tcm_indexs = ak.flatten(tcm.table_key) == table_id
+        idx_ch = ak.flatten(tcm.row_in_table)[chan_tcm_indexs].to_numpy()
 
         if ch not in channels_skip:
             res = utils.get_data_at_channel(
@@ -195,42 +189,37 @@
             if out is None:
                 # define dimension of output array
                 out = utils.make_numpy_full(n_rows, default_value, res.dtype)
-        else:
-            res = np.full(len(idx_ch), default_value)
-
-        # get mask from query
-        limarr = utils.get_mask_from_query(
-            datainfo=datainfo,
-            query=query,
-            length=len(res),
-            ch=ch,
-            idx_ch=idx_ch,
-        )
-
-        evt_ids_ch = np.searchsorted(
-            tcm.cumulative_length,
-            np.where(tcm.id == table_id)[0],
-            side="right",
-        )
-
-        # switch through modes
-        if "sum" == mode:
-            if res.dtype == bool:
-                res = res.astype(int)
 
-            out[evt_ids_ch] = np.where(limarr, res + out[evt_ids_ch], out[evt_ids_ch])
+            # get mask from query
+            limarr = utils.get_mask_from_query(
+                datainfo=datainfo,
+                query=query,
+                length=len(res),
+                ch=ch,
+                idx_ch=idx_ch,
+            )
 
-        if "any" == mode:
-            if res.dtype != bool:
-                res = res.astype(bool)
+            evt_ids_ch = np.repeat(
+                np.arange(0, len(tcm.table_key)),
+                ak.sum(tcm.table_key == table_id, axis=1),
+            )
 
-            out[evt_ids_ch] = out[evt_ids_ch] | (res & limarr)
+            # switch through modes
+            if mode == "sum":
+                if res.dtype == bool:
+                    res = res.astype(int)
+                if out.dtype == bool:
+                    out = out.astype(int)
+                out[evt_ids_ch[limarr]] += res[limarr]
+            else:
+                if res.dtype != bool:
+                    res = res.astype(bool)
 
-        if "all" == mode:
-            if res.dtype != bool:
-                res = res.astype(bool)
+                if mode == "any":
+                    out[evt_ids_ch] |= res & limarr
 
-            out[evt_ids_ch] = out[evt_ids_ch] & res & limarr
+                if mode == "all":
+                    out[evt_ids_ch] &= res & limarr
 
     return types.Array(nda=out)
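
In evaluate_to_scalar the evt_ids_ch array built with np.repeat maps every hit of a channel to its event, and the sum/any/all modes write through it. A small sketch of the "sum" mode using np.add.at, which accumulates repeated event indices (a channel firing twice in one event); this is a stand-in for the in-place update in the patch, with made-up data:

import numpy as np

n_rows = 4                                    # number of events
out = np.zeros(n_rows)

evt_ids_ch = np.array([0, 1, 1, 3])           # event of each hit of this channel
res = np.array([1.0, 2.0, 3.0, 4.0])          # evaluated expression per hit
limarr = np.array([True, True, True, False])  # query mask

# unbuffered add: both hits falling into event 1 contribute
np.add.at(out, evt_ids_ch[limarr], res[limarr])
print(out)  # [1. 5. 0. 0.]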

@@ -274,16 +263,20 @@ def evaluate_at_channel(
 
     out = None
 
-    for ch in np.unique(ch_comp.nda.astype(int)):
-        table_name = utils.get_table_name_by_pattern(table_id_fmt, ch)
+    for table_id in np.unique(ch_comp.nda.astype(int)):
+        table_name = utils.get_table_name_by_pattern(table_id_fmt, table_id)
         # skip default value
         if table_name not in lh5.ls(f.hit.file):
             continue
 
-        idx_ch = tcm.idx[tcm.id == ch]
-        evt_ids_ch = np.searchsorted(
-            tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
+        # get index list for this channel to be loaded
+        chan_tcm_indexs = ak.flatten(tcm.table_key) == table_id
+        idx_ch = ak.flatten(tcm.row_in_table)[chan_tcm_indexs].to_numpy()
+
+        evt_ids_ch = np.repeat(
+            np.arange(0, len(tcm.table_key)), ak.sum(tcm.table_key == table_id, axis=1)
         )
+
         if (table_name in channels) and (table_name not in channels_skip):
             res = utils.get_data_at_channel(
                 datainfo=datainfo,
@@ -299,7 +292,9 @@
         if out is None:
             out = utils.make_numpy_full(len(ch_comp.nda), default_value, res.dtype)
 
-        out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch])
+        out[evt_ids_ch] = np.where(
+            table_id == ch_comp.nda[idx_ch], res, out[evt_ids_ch]
+        )
 
     return types.Array(nda=out)
 
@@ -348,10 +343,10 @@ def evaluate_at_channel_vov(
     )
 
     type_name = None
-    for ch in ch_comp_channels:
-        table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch)
-        evt_ids_ch = np.searchsorted(
-            tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
+    for table_id in ch_comp_channels:
+        table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, table_id)
+        evt_ids_ch = np.repeat(
+            np.arange(0, len(tcm.table_key)), ak.sum(tcm.table_key == table_id, axis=1)
         )
         if (table_name in channels) and (table_name not in channels_skip):
             res = utils.get_data_at_channel(
@@ -362,20 +357,19 @@
                 field_list=field_list,
                 pars_dict=pars_dict,
             )
-            new_evt_ids_ch = np.searchsorted(
-                ch_comp.cumulative_length,
-                np.where(ch_comp.flattened_data.nda == ch)[0],
-                "right",
+            new_evt_ids_ch = np.repeat(
+                np.arange(0, len(ch_comp)),
+                ak.sum(ch_comp.view_as("ak") == table_id, axis=1),
             )
             matches = np.isin(evt_ids_ch, new_evt_ids_ch)
-            out[ch_comp.flattened_data.nda == ch] = res[matches]
+            out[ch_comp.flattened_data.nda == table_id] = res[matches]
 
         else:
-            length = len(np.where(ch_comp.flattened_data.nda == ch)[0])
+            length = len(np.where(ch_comp.flattened_data.nda == table_id)[0])
             res = np.full(length, default_value)
-            out[ch_comp.flattened_data.nda == ch] = res
+            out[ch_comp.flattened_data.nda == table_id] = res
 
-        if ch == ch_comp_channels[0]:
+        if table_id == ch_comp_channels[0]:
             out = out.astype(res.dtype)
             type_name = res.dtype
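
In evaluate_at_channel_vov, np.isin compares the two event-index maps to pick out the hits whose event actually requests this channel in ch_comp. Stripped-down illustration with made-up arrays:

import numpy as np

evt_ids_ch = np.array([0, 1, 2, 3])       # events in which this channel has a hit
new_evt_ids_ch = np.array([1, 3])         # events whose ch_comp entry selects this channel
res = np.array([10.0, 11.0, 12.0, 13.0])  # evaluated expression per hit

# keep only the hits whose event is actually asking for this channel
matches = np.isin(evt_ids_ch, new_evt_ids_ch)
print(res[matches])  # [11. 13.]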

@@ -438,12 +432,13 @@ def evaluate_to_aoesa(
 
     for i, ch in enumerate(channels):
         table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch)
-        idx_ch = tcm.idx[tcm.id == table_id]
 
-        evt_ids_ch = np.searchsorted(
-            tcm.cumulative_length,
-            np.where(tcm.id == table_id)[0],
-            "right",
+        # get index list for this channel to be loaded
+        chan_tcm_indexs = ak.flatten(tcm.table_key) == table_id
+        idx_ch = ak.flatten(tcm.row_in_table)[chan_tcm_indexs].to_numpy()
+
+        evt_ids_ch = np.repeat(
+            np.arange(0, len(tcm.table_key)), ak.sum(tcm.table_key == table_id, axis=1)
        )
 
         if ch not in channels_skip: