Skip to content

Commit d120742

Browse files
authored
Update inference_cost.py
1 parent 7608e7c commit d120742

File tree

1 file changed

+13
-21
lines changed

1 file changed

+13
-21
lines changed

src/qonnx/util/inference_cost.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
6969
total_mem_elems += v
7070
return total_mem_bits, total_mem_elems
7171

72+
def assign_mem_bits_and_elems(res_dict):
73+
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res_dict, "mem_w")
74+
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res_dict, "mem_o")
75+
res_dict["total_mem_w_bits"] = mem_w_bits
76+
res_dict["total_mem_w_elems"] = mem_w_elems
77+
res_dict["total_mem_o_bits"] = mem_o_bits
78+
res_dict["total_mem_o_elems"] = mem_o_elems
79+
return res_dict
7280

7381
def inference_cost(
7482
model_filename_or_wrapper,
@@ -114,14 +122,9 @@ def inference_cost(
114122
for i, res in ret.items():
115123
if i == "total_cost":
116124
bops, macs = compute_bops_and_macs(res)
117-
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res, "mem_w")
118-
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res, "mem_o")
125+
res = assign_mem_bits_and_elems(res)
119126
res["total_bops"] = bops
120127
res["total_macs"] = macs
121-
res["total_mem_w_bits"] = mem_w_bits
122-
res["total_mem_w_elems"] = mem_w_elems
123-
res["total_mem_o_bits"] = mem_o_bits
124-
res["total_mem_o_elems"] = mem_o_elems
125128
if "unsupported" in res:
126129
res["unsupported"] = str(res["unsupported"])
127130
if output_json is not None:
@@ -132,31 +135,20 @@ def inference_cost(
132135
per_optype_breakdown = {}
133136
for optype, op_res in res.items():
134137
bops, macs = compute_bops_and_macs(op_res)
135-
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(op_res, "mem_w")
136-
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(op_res, "mem_o")
138+
op_res = assign_mem_bits_and_elems(op_res)
137139
op_res["total_bops"] = bops
138140
op_res["total_macs"] = macs
139-
op_res["total_mem_w_bits"] = mem_w_bits
140-
op_res["total_mem_w_elems"] = mem_w_elems
141-
op_res["total_mem_o_bits"] = mem_o_bits
142-
op_res["total_mem_o_elems"] = mem_o_elems
143141
per_optype_breakdown[optype] = op_res
144142
combined_results[i] = per_optype_breakdown
145143
else:
146144
per_node_breakdown = {}
147145
for node_name in res.keys():
148-
node_cost = res[node_name]
149-
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(node_cost, "mem_w")
150-
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(node_cost, "mem_o")
151-
node_cost["total_mem_w_bits"] = mem_w_bits
152-
node_cost["total_mem_w_elems"] = mem_w_elems
153-
node_cost["total_mem_o_bits"] = mem_o_bits
154-
node_cost["total_mem_o_elems"] = mem_o_elems
155-
per_node_breakdown[node_name] = node_cost
146+
node_res = res[node_name]
147+
node_res = assign_mem_bits_and_elems(node_res)
148+
per_node_breakdown[node_name] = node_res
156149
combined_results[i] = per_node_breakdown
157150
return combined_results
158151

159-
160152
def main():
161153
clize.run(inference_cost)
162154

0 commit comments

Comments
 (0)