6
6
7
7
import awkward as ak
8
8
import numpy as np
9
+ import pandas as pd
9
10
from lgdo import lh5 , types
10
- from lgdo .lh5 import LH5Store
11
11
12
12
from . import utils
13
13
@@ -58,15 +58,14 @@ def evaluate_to_first_or_last(
58
58
"""
59
59
f = utils .make_files_config (datainfo )
60
60
61
- out = None
62
- outt = None
63
- store = LH5Store (keep_open = True )
61
+ df = None
64
62
65
63
for ch in channels :
66
64
table_id = utils .get_tcm_id_by_pattern (f .hit .table_fmt , ch )
67
65
68
66
# get index list for this channel to be loaded
69
- idx_ch = tcm .idx [tcm .id == table_id ]
67
+ chan_tcm_indexs = ak .flatten (tcm .table_key ) == table_id
68
+ idx_ch = ak .flatten (tcm .row_in_table )[chan_tcm_indexs ].to_numpy ()
70
69
71
70
# evaluate at channel
72
71
if ch not in channels_skip :
@@ -79,58 +78,52 @@ def evaluate_to_first_or_last(
79
78
pars_dict = pars_dict ,
80
79
)
81
80
82
- if out is None :
81
+ if df is None :
83
82
# define dimension of output array
84
83
out = utils .make_numpy_full (n_rows , default_value , res .dtype )
85
- outt = np .zeros (len (out ))
86
- else :
87
- res = np .full (len (idx_ch ), default_value )
88
-
89
- # get mask from query
90
- limarr = utils .get_mask_from_query (
91
- datainfo = datainfo ,
92
- query = query ,
93
- length = len (res ),
94
- ch = ch ,
95
- idx_ch = idx_ch ,
96
- )
84
+ df = pd .DataFrame ({"sort_field" : np .zeros (len (out )), "res" : out })
97
85
98
- # find if sorter is in hit or dsp
99
- t0 = store .read (
100
- f"{ ch } /{ sorter [0 ]} /{ sorter [1 ]} " ,
101
- f .hit .file if f"{ f .hit .group } " == sorter [0 ] else f .dsp .file ,
102
- idx = idx_ch ,
103
- )[0 ].view_as ("np" )
86
+ # get mask from query
87
+ limarr = utils .get_mask_from_query (
88
+ datainfo = datainfo ,
89
+ query = query ,
90
+ length = len (res ),
91
+ ch = ch ,
92
+ idx_ch = idx_ch ,
93
+ )
104
94
105
- if t0 .ndim > 1 :
106
- raise ValueError (f"sorter '{ sorter [0 ]} /{ sorter [1 ]} ' must be a 1D array" )
95
+ # find if sorter is in hit or dsp
96
+ sort_field = lh5 .read_as (
97
+ f"{ ch } /{ sorter [0 ]} /{ sorter [1 ]} " ,
98
+ f .hit .file if f"{ f .hit .group } " == sorter [0 ] else f .dsp .file ,
99
+ idx = idx_ch ,
100
+ library = "np" ,
101
+ )
107
102
108
- evt_ids_ch = np .searchsorted (
109
- tcm .cumulative_length ,
110
- np .where (tcm .id == table_id )[0 ],
111
- "right" ,
112
- )
103
+ if sort_field .ndim > 1 :
104
+ raise ValueError (f"sorter '{ sorter [0 ]} /{ sorter [1 ]} ' must be a 1D array" )
113
105
114
- if is_first :
115
- if ch == channels [0 ]:
116
- outt [:] = np .inf
106
+ ch_df = pd .DataFrame ({"sort_field" : sort_field , "res" : res })
117
107
118
- out [evt_ids_ch ] = np .where (
119
- (t0 < outt [evt_ids_ch ]) & (limarr ), res , out [evt_ids_ch ]
120
- )
121
- outt [evt_ids_ch ] = np .where (
122
- (t0 < outt [evt_ids_ch ]) & (limarr ), t0 , outt [evt_ids_ch ]
108
+ evt_ids_ch = np .repeat (
109
+ np .arange (0 , len (tcm .table_key )),
110
+ ak .sum (tcm .table_key == table_id , axis = 1 ),
123
111
)
124
112
125
- else :
126
- out [evt_ids_ch ] = np .where (
127
- (t0 > outt [evt_ids_ch ]) & (limarr ), res , out [evt_ids_ch ]
128
- )
129
- outt [evt_ids_ch ] = np .where (
130
- (t0 > outt [evt_ids_ch ]) & (limarr ), t0 , outt [evt_ids_ch ]
131
- )
113
+ if is_first :
114
+ if ch == channels [0 ]:
115
+ df ["sort_field" ] = np .inf
116
+ ids = (
117
+ ch_df .sort_field .to_numpy () < df .sort_field [evt_ids_ch ].to_numpy ()
118
+ ) & (limarr )
119
+ else :
120
+ ids = (
121
+ ch_df .sort_field .to_numpy () > df .sort_field [evt_ids_ch ].to_numpy ()
122
+ ) & (limarr )
132
123
133
- return types .Array (nda = out )
124
+ df .loc [evt_ids_ch [ids ], list (df .columns )] = ch_df .loc [ids , list (df .columns )]
125
+
126
+ return types .Array (nda = df .res .to_numpy ())
134
127
135
128
136
129
def evaluate_to_scalar (
@@ -180,7 +173,8 @@ def evaluate_to_scalar(
180
173
table_id = utils .get_tcm_id_by_pattern (f .hit .table_fmt , ch )
181
174
182
175
# get index list for this channel to be loaded
183
- idx_ch = tcm .idx [tcm .id == table_id ]
176
+ chan_tcm_indexs = ak .flatten (tcm .table_key ) == table_id
177
+ idx_ch = ak .flatten (tcm .row_in_table )[chan_tcm_indexs ].to_numpy ()
184
178
185
179
if ch not in channels_skip :
186
180
res = utils .get_data_at_channel (
@@ -195,42 +189,37 @@ def evaluate_to_scalar(
195
189
if out is None :
196
190
# define dimension of output array
197
191
out = utils .make_numpy_full (n_rows , default_value , res .dtype )
198
- else :
199
- res = np .full (len (idx_ch ), default_value )
200
-
201
- # get mask from query
202
- limarr = utils .get_mask_from_query (
203
- datainfo = datainfo ,
204
- query = query ,
205
- length = len (res ),
206
- ch = ch ,
207
- idx_ch = idx_ch ,
208
- )
209
-
210
- evt_ids_ch = np .searchsorted (
211
- tcm .cumulative_length ,
212
- np .where (tcm .id == table_id )[0 ],
213
- side = "right" ,
214
- )
215
-
216
- # switch through modes
217
- if "sum" == mode :
218
- if res .dtype == bool :
219
- res = res .astype (int )
220
192
221
- out [evt_ids_ch ] = np .where (limarr , res + out [evt_ids_ch ], out [evt_ids_ch ])
193
+ # get mask from query
194
+ limarr = utils .get_mask_from_query (
195
+ datainfo = datainfo ,
196
+ query = query ,
197
+ length = len (res ),
198
+ ch = ch ,
199
+ idx_ch = idx_ch ,
200
+ )
222
201
223
- if "any" == mode :
224
- if res .dtype != bool :
225
- res = res .astype (bool )
202
+ evt_ids_ch = np .repeat (
203
+ np .arange (0 , len (tcm .table_key )),
204
+ ak .sum (tcm .table_key == table_id , axis = 1 ),
205
+ )
226
206
227
- out [evt_ids_ch ] = out [evt_ids_ch ] | (res & limarr )
207
+ # switch through modes
208
+ if mode == "sum" :
209
+ if res .dtype == bool :
210
+ res = res .astype (int )
211
+ if out .dtype == bool :
212
+ out = out .astype (int )
213
+ out [evt_ids_ch [limarr ]] += res [limarr ]
214
+ else :
215
+ if res .dtype != bool :
216
+ res = res .astype (bool )
228
217
229
- if "all" == mode :
230
- if res .dtype != bool :
231
- res = res .astype (bool )
218
+ if mode == "any" :
219
+ out [evt_ids_ch ] |= res & limarr
232
220
233
- out [evt_ids_ch ] = out [evt_ids_ch ] & res & limarr
221
+ if mode == "all" :
222
+ out [evt_ids_ch ] &= res & limarr
234
223
235
224
return types .Array (nda = out )
236
225
@@ -274,16 +263,20 @@ def evaluate_at_channel(
274
263
275
264
out = None
276
265
277
- for ch in np .unique (ch_comp .nda .astype (int )):
278
- table_name = utils .get_table_name_by_pattern (table_id_fmt , ch )
266
+ for table_id in np .unique (ch_comp .nda .astype (int )):
267
+ table_name = utils .get_table_name_by_pattern (table_id_fmt , table_id )
279
268
# skip default value
280
269
if table_name not in lh5 .ls (f .hit .file ):
281
270
continue
282
271
283
- idx_ch = tcm .idx [tcm .id == ch ]
284
- evt_ids_ch = np .searchsorted (
285
- tcm .cumulative_length , np .where (tcm .id == ch )[0 ], "right"
272
+ # get index list for this channel to be loaded
273
+ chan_tcm_indexs = ak .flatten (tcm .table_key ) == table_id
274
+ idx_ch = ak .flatten (tcm .row_in_table )[chan_tcm_indexs ].to_numpy ()
275
+
276
+ evt_ids_ch = np .repeat (
277
+ np .arange (0 , len (tcm .table_key )), ak .sum (tcm .table_key == table_id , axis = 1 )
286
278
)
279
+
287
280
if (table_name in channels ) and (table_name not in channels_skip ):
288
281
res = utils .get_data_at_channel (
289
282
datainfo = datainfo ,
@@ -299,7 +292,9 @@ def evaluate_at_channel(
299
292
if out is None :
300
293
out = utils .make_numpy_full (len (ch_comp .nda ), default_value , res .dtype )
301
294
302
- out [evt_ids_ch ] = np .where (ch == ch_comp .nda [idx_ch ], res , out [evt_ids_ch ])
295
+ out [evt_ids_ch ] = np .where (
296
+ table_id == ch_comp .nda [idx_ch ], res , out [evt_ids_ch ]
297
+ )
303
298
304
299
return types .Array (nda = out )
305
300
@@ -348,10 +343,10 @@ def evaluate_at_channel_vov(
348
343
)
349
344
350
345
type_name = None
351
- for ch in ch_comp_channels :
352
- table_name = utils .get_table_name_by_pattern (f .hit .table_fmt , ch )
353
- evt_ids_ch = np .searchsorted (
354
- tcm .cumulative_length , np . where (tcm .id == ch )[ 0 ], "right"
346
+ for table_id in ch_comp_channels :
347
+ table_name = utils .get_table_name_by_pattern (f .hit .table_fmt , table_id )
348
+ evt_ids_ch = np .repeat (
349
+ np . arange ( 0 , len ( tcm .table_key )), ak . sum (tcm .table_key == table_id , axis = 1 )
355
350
)
356
351
if (table_name in channels ) and (table_name not in channels_skip ):
357
352
res = utils .get_data_at_channel (
@@ -362,20 +357,19 @@ def evaluate_at_channel_vov(
362
357
field_list = field_list ,
363
358
pars_dict = pars_dict ,
364
359
)
365
- new_evt_ids_ch = np .searchsorted (
366
- ch_comp .cumulative_length ,
367
- np .where (ch_comp .flattened_data .nda == ch )[0 ],
368
- "right" ,
360
+ new_evt_ids_ch = np .repeat (
361
+ np .arange (0 , len (ch_comp )),
362
+ ak .sum (ch_comp .view_as ("ak" ) == table_id , axis = 1 ),
369
363
)
370
364
matches = np .isin (evt_ids_ch , new_evt_ids_ch )
371
- out [ch_comp .flattened_data .nda == ch ] = res [matches ]
365
+ out [ch_comp .flattened_data .nda == table_id ] = res [matches ]
372
366
373
367
else :
374
- length = len (np .where (ch_comp .flattened_data .nda == ch )[0 ])
368
+ length = len (np .where (ch_comp .flattened_data .nda == table_id )[0 ])
375
369
res = np .full (length , default_value )
376
- out [ch_comp .flattened_data .nda == ch ] = res
370
+ out [ch_comp .flattened_data .nda == table_id ] = res
377
371
378
- if ch == ch_comp_channels [0 ]:
372
+ if table_id == ch_comp_channels [0 ]:
379
373
out = out .astype (res .dtype )
380
374
type_name = res .dtype
381
375
@@ -438,12 +432,13 @@ def evaluate_to_aoesa(
438
432
439
433
for i , ch in enumerate (channels ):
440
434
table_id = utils .get_tcm_id_by_pattern (f .hit .table_fmt , ch )
441
- idx_ch = tcm .idx [tcm .id == table_id ]
442
435
443
- evt_ids_ch = np .searchsorted (
444
- tcm .cumulative_length ,
445
- np .where (tcm .id == table_id )[0 ],
446
- "right" ,
436
+ # get index list for this channel to be loaded
437
+ chan_tcm_indexs = ak .flatten (tcm .table_key ) == table_id
438
+ idx_ch = ak .flatten (tcm .row_in_table )[chan_tcm_indexs ].to_numpy ()
439
+
440
+ evt_ids_ch = np .repeat (
441
+ np .arange (0 , len (tcm .table_key )), ak .sum (tcm .table_key == table_id , axis = 1 )
447
442
)
448
443
449
444
if ch not in channels_skip :
0 commit comments