Skip to content

Commit db969e6

Browse files
authored
Merge pull request #101 from fastmachinelearning/inference_cost_breakdown
inference cost breakdown
2 parents c5bd87f + a4e7e35 commit db969e6

File tree

8 files changed

+246
-103
lines changed

8 files changed

+246
-103
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ Inference cost for CNV_2W2A.onnx
101101
}
102102
```
103103

104+
You can use the `--cost-breakdown` option to generate a more detailed report that covers per-node (by name) and per-op-type information.
104105
You can read more about the BOPS metric in [this paper](https://www.frontiersin.org/articles/10.3389/frai.2021.676564/full), Section 4.2 Bit Operations.
105106

106107
### Convert between different quantization representations

src/qonnx/analysis/inference_cost.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def inference_cost_conv(model, node, discount_sparsity):
117117
mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
118118
w_mem_type_str = "mem_w_%s" % (wdt_name)
119119
o_mem_type_str = "mem_o_%s" % (odt_name)
120+
# convert to floats so the values remain JSON-serializable
121+
n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
120122
ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
121123
return ret
122124

@@ -161,6 +163,8 @@ def inference_cost_matmul(model, node, discount_sparsity):
161163
mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
162164
w_mem_type_str = "mem_w_%s" % (wdt_name)
163165
o_mem_type_str = "mem_o_%s" % (odt_name)
166+
# convert to floats so the values remain JSON-serializable
167+
n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
164168
ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
165169
return ret
166170

@@ -197,14 +201,16 @@ def inference_cost_upsample(model, node, discount_sparsity):
197201
mac_op_type_str = "op_mac_%s_%s" % (idt_name, idt_name)
198202
o_mem_type_str = "mem_o_%s" % (odt_name)
199203

204+
# convert to floats so the values remain JSON-serializable
205+
n_macs, o_mem = float(n_macs), float(o_mem)
200206
ret = {mac_op_type_str: n_macs, o_mem_type_str: o_mem}
201207
return ret
202208

203209

204-
def inference_cost(model, discount_sparsity=True):
210+
def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
205211
"Ensure all nodes have unique names prior to calling this analysis pass."
206212

207-
node_costs = {}
213+
ret, node_costs, nodes_per_optype = {}, {}, {}
208214
zero_cost_ops = [
209215
"MaxPool",
210216
"AveragePool",
@@ -240,13 +246,24 @@ def inference_cost(model, discount_sparsity=True):
240246
if node.op_type in inference_cost_fxn_map.keys():
241247
node_cost = inference_cost_fxn_map[node.op_type](model, node, discount_sparsity)
242248
node_costs[node.name] = node_cost
249+
if node.op_type not in nodes_per_optype.keys():
250+
new_optype = {}
251+
new_optype[node.name] = node_cost
252+
nodes_per_optype[node.op_type] = new_optype
253+
else:
254+
nodes_per_optype[node.op_type][node.name] = node_cost
243255
elif node.op_type in zero_cost_ops:
244256
continue
245257
else:
246258
unsupported_ops.add(node.op_type)
247-
248-
ret = aggregate_dict_keys(node_costs)
249-
ret["unsupported"] = unsupported_ops
250-
ret["discount_sparsity"] = discount_sparsity
251-
259+
total = aggregate_dict_keys(node_costs)
260+
total["unsupported"] = unsupported_ops
261+
total["discount_sparsity"] = discount_sparsity
262+
ret["total_cost"] = total
263+
if cost_breakdown:
264+
optype_cost = {}
265+
for optype, resources in nodes_per_optype.items():
266+
optype_cost[optype] = aggregate_dict_keys(resources)
267+
ret["optype_cost"] = optype_cost
268+
ret["node_cost"] = node_costs
252269
return ret

src/qonnx/util/inference_cost.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,24 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
7070
return total_mem_bits, total_mem_elems
7171

7272

73+
def assign_mem_bits_and_elems(res_dict):
74+
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res_dict, "mem_w")
75+
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res_dict, "mem_o")
76+
res_dict["total_mem_w_bits"] = mem_w_bits
77+
res_dict["total_mem_w_elems"] = mem_w_elems
78+
res_dict["total_mem_o_bits"] = mem_o_bits
79+
res_dict["total_mem_o_elems"] = mem_o_elems
80+
return res_dict
81+
82+
7383
def inference_cost(
74-
model_filename_or_wrapper, *, output_json=None, output_onnx=None, preprocess=True, discount_sparsity=True
84+
model_filename_or_wrapper,
85+
*,
86+
output_json=None,
87+
output_onnx=None,
88+
preprocess=True,
89+
discount_sparsity=True,
90+
cost_breakdown=False
7591
):
7692
"""Return the inference cost estimate metric for given ONNX model.
7793
Supports the Quant op for weight/activation quantization.
@@ -84,7 +100,10 @@ def inference_cost(
84100
datatype inference and constant folding. Strongly recommended.
85101
:param discount_sparsity: If set, will discount op cost of MAC ops with a
86102
constant zero weight, and the mem cost of constant zero weights.
87-
"""
103+
:param cost_breakdown: If set, include per-node (by name) and per-op-type
104+
breakdowns as part of the returned inference cost dict."""
105+
106+
combined_results = {}
88107
if isinstance(model_filename_or_wrapper, ModelWrapper):
89108
model = model_filename_or_wrapper
90109
else:
@@ -104,25 +123,29 @@ def inference_cost(
104123
model = model.transform(GiveReadableTensorNames())
105124
if output_onnx is not None:
106125
model.save(output_onnx)
107-
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity))
108-
bops, macs = compute_bops_and_macs(ret)
109-
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(ret, "mem_w")
110-
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(ret, "mem_o")
111-
ret["total_bops"] = bops
112-
ret["total_macs"] = macs
113-
ret["total_mem_w_bits"] = mem_w_bits
114-
ret["total_mem_w_elems"] = mem_w_elems
115-
ret["total_mem_o_bits"] = mem_o_bits
116-
ret["total_mem_o_elems"] = mem_o_elems
117-
118-
if "unsupported" in ret:
119-
ret["unsupported"] = str(ret["unsupported"])
120-
126+
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown))
127+
for i, res in ret.items():
128+
if i == "total_cost":
129+
bops, macs = compute_bops_and_macs(res)
130+
res = assign_mem_bits_and_elems(res)
131+
res["total_bops"] = bops
132+
res["total_macs"] = macs
133+
if "unsupported" in res:
134+
res["unsupported"] = str(res["unsupported"])
135+
combined_results[i] = res
136+
elif i in ["optype_cost", "node_cost"]:
137+
per_optype_or_node_breakdown = {}
138+
for optype, op_res in res.items():
139+
bops, macs = compute_bops_and_macs(op_res)
140+
op_res = assign_mem_bits_and_elems(op_res)
141+
op_res["total_bops"] = bops
142+
op_res["total_macs"] = macs
143+
per_optype_or_node_breakdown[optype] = op_res
144+
combined_results[i] = per_optype_or_node_breakdown
121145
if output_json is not None:
122146
with open(output_json, "w") as f:
123-
json.dump(ret, f, sort_keys=True, indent=2)
124-
125-
return ret
147+
json.dump(combined_results, f, sort_keys=True, indent=2)
148+
return combined_results
126149

127150

128151
def main():

tests/analysis/test_inference_cost.py

Lines changed: 82 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -34,90 +34,102 @@
3434
model_details_infcost = {
3535
"FINN-CNV_W2A2": {
3636
"expected_sparse": {
37-
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
38-
"mem_w_INT2": 908033.0,
39-
"mem_o_SCALEDINT<32>": 57600.0,
40-
"op_mac_INT2_INT2": 35615771.0,
41-
"mem_o_INT32": 85002.0,
42-
"unsupported": "set()",
43-
"discount_sparsity": True,
44-
"total_bops": 163991084.0,
45-
"total_macs": 36961271.0,
46-
"total_mem_w_bits": 1816066.0,
47-
"total_mem_w_elems": 908033.0,
48-
"total_mem_o_bits": 4563264.0,
49-
"total_mem_o_elems": 142602.0,
37+
"total_cost": {
38+
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
39+
"mem_w_INT2": 908033.0,
40+
"mem_o_SCALEDINT<32>": 57600.0,
41+
"op_mac_INT2_INT2": 35615771.0,
42+
"mem_o_INT32": 85002.0,
43+
"unsupported": "set()",
44+
"discount_sparsity": True,
45+
"total_bops": 163991084.0,
46+
"total_macs": 36961271.0,
47+
"total_mem_w_bits": 1816066.0,
48+
"total_mem_w_elems": 908033.0,
49+
"total_mem_o_bits": 4563264.0,
50+
"total_mem_o_elems": 142602.0,
51+
}
5052
},
5153
"expected_dense": {
52-
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
53-
"mem_w_INT2": 1542848.0,
54-
"mem_o_SCALEDINT<32>": 57600.0,
55-
"op_mac_INT2_INT2": 57906176.0,
56-
"mem_o_INT32": 85002.0,
57-
"unsupported": "set()",
58-
"discount_sparsity": False,
59-
"total_bops": 256507904.0,
60-
"total_macs": 59461376.0,
61-
"total_mem_w_bits": 3085696.0,
62-
"total_mem_w_elems": 1542848.0,
63-
"total_mem_o_bits": 4563264.0,
64-
"total_mem_o_elems": 142602.0,
54+
"total_cost": {
55+
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
56+
"mem_w_INT2": 1542848.0,
57+
"mem_o_SCALEDINT<32>": 57600.0,
58+
"op_mac_INT2_INT2": 57906176.0,
59+
"mem_o_INT32": 85002.0,
60+
"unsupported": "set()",
61+
"discount_sparsity": False,
62+
"total_bops": 256507904.0,
63+
"total_macs": 59461376.0,
64+
"total_mem_w_bits": 3085696.0,
65+
"total_mem_w_elems": 1542848.0,
66+
"total_mem_o_bits": 4563264.0,
67+
"total_mem_o_elems": 142602.0,
68+
}
6569
},
6670
},
6771
"FINN-TFC_W2A2": {
6872
"expected_sparse": {
69-
"op_mac_INT2_INT2": 22355.0,
70-
"mem_w_INT2": 22355.0,
71-
"mem_o_INT32": 202.0,
72-
"unsupported": "set()",
73-
"discount_sparsity": True,
74-
"total_bops": 89420.0,
75-
"total_macs": 22355.0,
76-
"total_mem_w_bits": 44710.0,
77-
"total_mem_w_elems": 22355.0,
78-
"total_mem_o_bits": 6464.0,
79-
"total_mem_o_elems": 202.0,
73+
"total_cost": {
74+
"op_mac_INT2_INT2": 22355.0,
75+
"mem_w_INT2": 22355.0,
76+
"mem_o_INT32": 202.0,
77+
"unsupported": "set()",
78+
"discount_sparsity": True,
79+
"total_bops": 89420.0,
80+
"total_macs": 22355.0,
81+
"total_mem_w_bits": 44710.0,
82+
"total_mem_w_elems": 22355.0,
83+
"total_mem_o_bits": 6464.0,
84+
"total_mem_o_elems": 202.0,
85+
}
8086
},
8187
"expected_dense": {
82-
"op_mac_INT2_INT2": 59008.0,
83-
"mem_w_INT2": 59008.0,
84-
"mem_o_INT32": 202.0,
85-
"unsupported": "set()",
86-
"discount_sparsity": False,
87-
"total_bops": 236032.0,
88-
"total_macs": 59008.0,
89-
"total_mem_w_bits": 118016.0,
90-
"total_mem_w_elems": 59008.0,
91-
"total_mem_o_bits": 6464.0,
92-
"total_mem_o_elems": 202.0,
88+
"total_cost": {
89+
"op_mac_INT2_INT2": 59008.0,
90+
"mem_w_INT2": 59008.0,
91+
"mem_o_INT32": 202.0,
92+
"unsupported": "set()",
93+
"discount_sparsity": False,
94+
"total_bops": 236032.0,
95+
"total_macs": 59008.0,
96+
"total_mem_w_bits": 118016.0,
97+
"total_mem_w_elems": 59008.0,
98+
"total_mem_o_bits": 6464.0,
99+
"total_mem_o_elems": 202.0,
100+
}
93101
},
94102
},
95103
"RadioML_VGG10": {
96104
"expected_sparse": {
97-
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
98-
"mem_w_SCALEDINT<8>": 155617.0,
99-
"mem_o_SCALEDINT<32>": 130328.0,
100-
"unsupported": "set()",
101-
"discount_sparsity": True,
102-
"total_bops": 807699904.0,
103-
"total_macs": 12620311.0,
104-
"total_mem_w_bits": 1244936.0,
105-
"total_mem_w_elems": 155617.0,
106-
"total_mem_o_bits": 4170496.0,
107-
"total_mem_o_elems": 130328.0,
105+
"total_cost": {
106+
"unsupported": "set()",
107+
"discount_sparsity": True,
108+
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
109+
"mem_w_SCALEDINT<8>": 155617.0,
110+
"mem_o_SCALEDINT<32>": 130328.0,
111+
"total_bops": 807699904.0,
112+
"total_macs": 12620311.0,
113+
"total_mem_w_bits": 1244936.0,
114+
"total_mem_w_elems": 155617.0,
115+
"total_mem_o_bits": 4170496.0,
116+
"total_mem_o_elems": 130328.0,
117+
}
108118
},
109119
"expected_dense": {
110-
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
111-
"mem_w_SCALEDINT<8>": 159104.0,
112-
"mem_o_SCALEDINT<32>": 130328.0,
113-
"unsupported": "set()",
114-
"discount_sparsity": False,
115-
"total_bops": 823328768.0,
116-
"total_macs": 12864512.0,
117-
"total_mem_w_bits": 1272832.0,
118-
"total_mem_w_elems": 159104.0,
119-
"total_mem_o_bits": 4170496.0,
120-
"total_mem_o_elems": 130328.0,
120+
"total_cost": {
121+
"unsupported": "set()",
122+
"discount_sparsity": False,
123+
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
124+
"mem_w_SCALEDINT<8>": 159104.0,
125+
"mem_o_SCALEDINT<32>": 130328.0,
126+
"total_bops": 823328768.0,
127+
"total_macs": 12864512.0,
128+
"total_mem_w_bits": 1272832.0,
129+
"total_mem_w_elems": 159104.0,
130+
"total_mem_o_bits": 4170496.0,
131+
"total_mem_o_elems": 130328.0,
132+
}
121133
},
122134
},
123135
}

0 commit comments

Comments
 (0)