Skip to content

Commit c5eeaa7

Browse files
committed
More general python codegen
1 parent 5be2246 commit c5eeaa7

File tree

1 file changed

+55
-44
lines changed

1 file changed

+55
-44
lines changed

code_generator_python/python_code_generator/templates/transform_single_file.py

Lines changed: 55 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -22,58 +22,69 @@ def transform_single_file(file_path: str, output_path: Path, output_format: str)
2222
try:
2323
stime = time.time()
2424

25-
output = generated_transformer.run_query(file_path)
26-
27-
ttime = time.time()
28-
29-
if output_format == 'root-file':
25+
# We first see if the function has the signature to directly write output
26+
# If it doesn't, then we assume it's giving us back awkward array results
27+
try:
28+
generated_transformer.run_query(file_path, output_path)
29+
if not output_path.exists():
30+
raise RuntimeError("Transformation did not produce expected output file "
31+
f"{output_path}")
32+
ttime = time.time()
3033
etime = time.time()
31-
if isinstance(output, ak.Array):
32-
awkward_arrays = {default_tree_name: output}
33-
elif isinstance(output, dict):
34-
awkward_arrays = output
35-
with open(output_path, 'b+w') as wfile:
36-
with uproot.recreate(wfile) as writer:
37-
for key in awkward_arrays.keys():
38-
total_events = awkward_arrays[key].__len__()
39-
if awkward_arrays[key].fields and total_events:
40-
o_dict = {field: awkward_arrays[key][field]
41-
for field in awkward_arrays[key].fields}
42-
elif awkward_arrays[key].fields and not total_events:
43-
o_dict = {field: np.array([])
44-
for field in awkward_arrays[key].fields}
45-
elif not awkward_arrays[key].fields and total_events:
46-
o_dict = {default_branch_name: awkward_arrays[key]}
47-
else:
48-
o_dict = {default_branch_name: np.array([])}
49-
writer[key] = o_dict
50-
5134
wtime = time.time()
52-
elif output_format == 'raw-file':
53-
etime = time.time()
5435
total_events = 0
55-
output_path = output
56-
wtime = time.time()
57-
else:
58-
if isinstance(output, dict):
59-
tree_name = list(output.keys())[0]
60-
awkward_array = output[tree_name]
61-
print(f'Returned type from your Python function is a dictionary - '
62-
f'Only the first key {tree_name} will be written as parquet files. '
63-
f'Please use root-file output to write all trees.')
36+
except AttributeError:
37+
output = generated_transformer.run_query(file_path)
38+
39+
ttime = time.time()
40+
if output_format == 'root-file':
41+
etime = time.time()
42+
if isinstance(output, ak.Array):
43+
awkward_arrays = {default_tree_name: output}
44+
elif isinstance(output, dict):
45+
awkward_arrays = output
46+
with open(output_path, 'b+w') as wfile:
47+
with uproot.recreate(wfile) as writer:
48+
for key in awkward_arrays.keys():
49+
total_events = awkward_arrays[key].__len__()
50+
if awkward_arrays[key].fields and total_events:
51+
o_dict = {field: awkward_arrays[key][field]
52+
for field in awkward_arrays[key].fields}
53+
elif awkward_arrays[key].fields and not total_events:
54+
o_dict = {field: np.array([])
55+
for field in awkward_arrays[key].fields}
56+
elif not awkward_arrays[key].fields and total_events:
57+
o_dict = {default_branch_name: awkward_arrays[key]}
58+
else:
59+
o_dict = {default_branch_name: np.array([])}
60+
writer[key] = o_dict
61+
62+
wtime = time.time()
63+
elif output_format == 'raw-file':
64+
etime = time.time()
65+
total_events = 0
66+
output_path = output
67+
wtime = time.time()
6468
else:
65-
awkward_array = output
69+
if isinstance(output, dict):
70+
tree_name = list(output.keys())[0]
71+
awkward_array = output[tree_name]
72+
print(f'Returned type from your Python function is a dictionary - '
73+
f'Only the first key {tree_name} will be written as parquet files. '
74+
f'Please use root-file output to write all trees.')
75+
else:
76+
awkward_array = output
6677

67-
total_events = ak.num(awkward_array, axis=0)
68-
arrow = ak.to_arrow_table(awkward_array)
78+
total_events = ak.num(awkward_array, axis=0)
79+
arrow = ak.to_arrow_table(awkward_array)
6980

70-
etime = time.time()
81+
etime = time.time()
7182

72-
writer = pq.ParquetWriter(output_path, arrow.schema)
73-
writer.write_table(table=arrow)
74-
writer.close()
83+
writer = pq.ParquetWriter(output_path, arrow.schema)
84+
writer.write_table(table=arrow)
85+
writer.close()
7586

76-
wtime = time.time()
87+
wtime = time.time()
7788

7889
output_size = os.stat(output_path).st_size
7990
print(f'Detailed transformer times. query_time:{round(ttime - stime, 3)} '

0 commit comments

Comments
 (0)