Skip to content

Commit a772619

Browse files
iguinn and pre-commit-ci[bot]
authored
Updated pargen.utils.load_data to use LH5Iterator and field_mask to be more memory efficient (#589)
Updated pargen.utils.load_data to use LH5Iterator and field_mask to be more memory efficient.

Co-authored-by: iguinn <iguinn@email.unc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3338331 commit a772619

File tree

1 file changed

+84
-70
lines changed

1 file changed

+84
-70
lines changed

src/pygama/pargen/utils.py

Lines changed: 84 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from lgdo import lh5
1010

1111
log = logging.getLogger(__name__)
12-
sto = lh5.LH5Store()
1312

1413

1514
def convert_to_minuit(pars, func):
@@ -35,101 +34,116 @@ def return_nans(input):
3534
return m.values, m.errors, np.full((len(m.values), len(m.values)), np.nan)
3635

3736

38-
def get_params(file_params, param_list):
39-
out_params = []
40-
if isinstance(file_params, dict):
41-
possible_keys = file_params.keys()
42-
elif isinstance(file_params, list):
43-
possible_keys = file_params
44-
for param in param_list:
45-
for key in possible_keys:
46-
if key in param:
47-
out_params.append(key)
48-
return np.unique(out_params).tolist()
49-
50-
5137
def load_data(
52-
files: list,
38+
files: str | list | dict,
5339
lh5_path: str,
5440
cal_dict: dict,
55-
params: list,
41+
params: set,
5642
cal_energy_param: str = "cuspEmax_ctc_cal",
5743
threshold=None,
5844
return_selection_mask=False,
59-
) -> tuple(np.array, np.array, np.array, np.array):
45+
) -> pd.DataFrame | tuple(pd.DataFrame, np.array):
6046
"""
61-
Loads in the A/E parameters needed and applies calibration constants to energy
47+
Loads parameters from data files. Applies calibration to cal_energy_param
48+
and uses this to apply a lower energy threshold.
49+
50+
files
51+
file or list of files or dict pointing from timestamps to lists of files
52+
lh5_path
53+
path to table in files
54+
cal_dict
55+
dictionary with operations used to apply calibration constants
56+
params
57+
list of parameters to load from file
58+
cal_energy_param
59+
name of the calibrated energy parameter that the threshold is applied to
60+
threshold
61+
lower energy threshold for events to load
62+
return_selection_mask
63+
if True, return selection mask for threshold along with data
6264
"""
6365

66+
params = set(params)
6467
if isinstance(files, str):
6568
files = [files]
6669

6770
if isinstance(files, dict):
68-
keys = lh5.ls(
69-
files[list(files)[0]][0],
70-
lh5_path if lh5_path[-1] == "/" else lh5_path + "/",
71-
)
72-
keys = [key.split("/")[-1] for key in keys]
73-
if list(files)[0] in cal_dict:
74-
params = get_params(keys + list(cal_dict[list(files)[0]].keys()), params)
75-
else:
76-
params = get_params(keys + list(cal_dict.keys()), params)
77-
71+
# Go through each tstamp and recursively load_data on file lists
7872
df = []
79-
all_files = []
80-
masks = np.array([], dtype=bool)
73+
masks = []
8174
for tstamp, tfiles in files.items():
82-
table = sto.read(lh5_path, tfiles)[0]
83-
84-
file_df = pd.DataFrame(columns=params)
85-
if tstamp in cal_dict:
86-
cal_dict_ts = cal_dict[tstamp]
75+
file_df = load_data(
76+
tfiles,
77+
lh5_path,
78+
cal_dict.get(tstamp, cal_dict),
79+
params,
80+
cal_energy_param,
81+
threshold,
82+
return_selection_mask,
83+
)
84+
85+
if return_selection_mask:
86+
file_df[0]["run_timestamp"] = np.full(
87+
len(file_df[0]), tstamp, dtype=object
88+
)
89+
df.append(file_df[0])
90+
masks.append(file_df[1])
8791
else:
88-
cal_dict_ts = cal_dict
92+
file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)
93+
df.append(file_df)
8994

90-
for outname, info in cal_dict_ts.items():
91-
outcol = table.eval(info["expression"], info.get("parameters", None))
92-
table.add_column(outname, outcol)
93-
94-
for param in params:
95-
file_df[param] = table[param]
95+
df = pd.concat(df)
96+
if return_selection_mask:
97+
masks = np.concatenate(masks)
9698

97-
file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)
99+
elif isinstance(files, list):
100+
# Get set of available fields between input table and cal_dict
101+
file_keys = lh5.ls(
102+
files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/"
103+
)
104+
file_keys = {key.split("/")[-1] for key in file_keys}
98105

99-
if threshold is not None:
100-
mask = file_df[cal_energy_param] > threshold
101-
file_df.drop(np.where(~mask)[0], inplace=True)
102-
else:
103-
mask = np.ones(len(file_df), dtype=bool)
104-
masks = np.append(masks, mask)
105-
df.append(file_df)
106-
all_files += tfiles
106+
# Get set of keys in calibration expressions that show up in file
107+
cal_keys = {
108+
name
109+
for info in cal_dict.values()
110+
for name in compile(info["expression"], "0vbb is real!", "eval").co_names
111+
} & file_keys
107112

108-
params.append("run_timestamp")
109-
df = pd.concat(df)
113+
# Get set of fields to read from files
114+
fields = cal_keys | (file_keys & params)
110115

111-
elif isinstance(files, list):
112-
keys = lh5.ls(files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/")
113-
keys = [key.split("/")[-1] for key in keys]
114-
params = get_params(keys + list(cal_dict.keys()), params)
115-
116-
table = sto.read(lh5_path, files)[0]
117-
df = pd.DataFrame(columns=params)
118-
for outname, info in cal_dict.items():
119-
outcol = table.eval(info["expression"], info.get("parameters", None))
120-
table.add_column(outname, outcol)
121-
for param in params:
122-
df[param] = table[param]
116+
lh5_it = lh5.iterator.LH5Iterator(
117+
files, lh5_path, field_mask=fields, buffer_len=100000
118+
)
119+
df_fields = params & (fields | set(cal_dict))
120+
if df_fields != params:
121+
log.debug(
122+
f"load_data(): params not found in data files or cal_dict: {params-df_fields}"
123+
)
124+
df = pd.DataFrame(columns=list(df_fields))
125+
126+
for table, entry, n_rows in lh5_it:
127+
# Evaluate all provided expressions and add to table
128+
for outname, info in cal_dict.items():
129+
table[outname] = table.eval(
130+
info["expression"], info.get("parameters", None)
131+
)
132+
133+
# Copy params in table into dataframe
134+
for par in df:
135+
# First set of entries: allocate enough memory for all entries
136+
if entry == 0:
137+
df[par] = np.resize(table[par], len(lh5_it))
138+
else:
139+
df.loc[entry : entry + n_rows - 1, par] = table[par][:n_rows]
140+
141+
# Evaluate threshold mask and drop events below threshold
123142
if threshold is not None:
124143
masks = df[cal_energy_param] > threshold
125144
df.drop(np.where(~masks)[0], inplace=True)
126145
else:
127146
masks = np.ones(len(df), dtype=bool)
128-
all_files = files
129-
130-
for col in list(df.keys()):
131-
if col not in params:
132-
df.drop(col, inplace=True, axis=1)
133147

134148
log.debug("data loaded")
135149
if return_selection_mask:

0 commit comments

Comments
 (0)