Skip to content

Commit 7608e7c

Browse files
author
Harish
committed
inference cost breakdown
1 parent 39442cb commit 7608e7c

File tree

3 files changed

+160
-29
lines changed

3 files changed

+160
-29
lines changed

src/qonnx/analysis/inference_cost.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ def inference_cost_upsample(model, node, discount_sparsity):
201201
return ret
202202

203203

204-
def inference_cost(model, discount_sparsity=True):
204+
def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
205205
"Ensure all nodes have unique names prior to calling this analysis pass."
206206

207-
node_costs = {}
207+
ret, node_costs, nodes_per_optype = {}, {}, {}
208208
zero_cost_ops = [
209209
"MaxPool",
210210
"AveragePool",
@@ -240,13 +240,24 @@ def inference_cost(model, discount_sparsity=True):
240240
if node.op_type in inference_cost_fxn_map.keys():
241241
node_cost = inference_cost_fxn_map[node.op_type](model, node, discount_sparsity)
242242
node_costs[node.name] = node_cost
243+
if node.op_type not in nodes_per_optype.keys():
244+
new_optype = {}
245+
new_optype[node.name] = node_cost
246+
nodes_per_optype[node.op_type] = new_optype
247+
else:
248+
nodes_per_optype[node.op_type][node.name] = node_cost
243249
elif node.op_type in zero_cost_ops:
244250
continue
245251
else:
246252
unsupported_ops.add(node.op_type)
247-
248-
ret = aggregate_dict_keys(node_costs)
249-
ret["unsupported"] = unsupported_ops
250-
ret["discount_sparsity"] = discount_sparsity
251-
253+
total = aggregate_dict_keys(node_costs)
254+
total["unsupported"] = unsupported_ops
255+
total["discount_sparsity"] = discount_sparsity
256+
ret["total_cost"] = total
257+
if cost_breakdown:
258+
optype_cost = {}
259+
for optype, resources in nodes_per_optype.items():
260+
optype_cost[optype] = aggregate_dict_keys(resources)
261+
ret["optype_cost"] = optype_cost
262+
ret["node_cost"] = node_costs
252263
return ret

src/qonnx/util/inference_cost.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
7171

7272

7373
def inference_cost(
74-
model_filename_or_wrapper, *, output_json=None, output_onnx=None, preprocess=True, discount_sparsity=True
74+
model_filename_or_wrapper,
75+
*,
76+
output_json=None,
77+
output_onnx=None,
78+
preprocess=True,
79+
discount_sparsity=True,
80+
cost_breakdown=False
7581
):
7682
"""Return the inference cost estimate metric for given ONNX model.
7783
Supports the Quant op for weight/activation quantization.
@@ -83,8 +89,8 @@ def inference_cost(
8389
:param preprocess: If set, run preprocessing steps such as shape inference,
8490
datatype inference and constant folding. Strongly recommended.
8591
:param discount_sparsity: If set, will discount op cost of MAC ops with a
86-
constant zero weight, and the mem cost of constant zero weights.
87-
"""
92+
constant zero weight, and the mem cost of constant zero weights."""
93+
combined_results = {}
8894
if isinstance(model_filename_or_wrapper, ModelWrapper):
8995
model = model_filename_or_wrapper
9096
else:
@@ -104,25 +110,51 @@ def inference_cost(
104110
model = model.transform(GiveReadableTensorNames())
105111
if output_onnx is not None:
106112
model.save(output_onnx)
107-
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity))
108-
bops, macs = compute_bops_and_macs(ret)
109-
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(ret, "mem_w")
110-
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(ret, "mem_o")
111-
ret["total_bops"] = bops
112-
ret["total_macs"] = macs
113-
ret["total_mem_w_bits"] = mem_w_bits
114-
ret["total_mem_w_elems"] = mem_w_elems
115-
ret["total_mem_o_bits"] = mem_o_bits
116-
ret["total_mem_o_elems"] = mem_o_elems
117-
118-
if "unsupported" in ret:
119-
ret["unsupported"] = str(ret["unsupported"])
120-
121-
if output_json is not None:
122-
with open(output_json, "w") as f:
123-
json.dump(ret, f, sort_keys=True, indent=2)
124-
125-
return ret
113+
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown))
114+
for i, res in ret.items():
115+
if i == "total_cost":
116+
bops, macs = compute_bops_and_macs(res)
117+
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res, "mem_w")
118+
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res, "mem_o")
119+
res["total_bops"] = bops
120+
res["total_macs"] = macs
121+
res["total_mem_w_bits"] = mem_w_bits
122+
res["total_mem_w_elems"] = mem_w_elems
123+
res["total_mem_o_bits"] = mem_o_bits
124+
res["total_mem_o_elems"] = mem_o_elems
125+
if "unsupported" in res:
126+
res["unsupported"] = str(res["unsupported"])
127+
if output_json is not None:
128+
with open(output_json, "w") as f:
129+
json.dump(res, f, sort_keys=True, indent=2)
130+
combined_results[i] = res
131+
elif i == "optype_cost":
132+
per_optype_breakdown = {}
133+
for optype, op_res in res.items():
134+
bops, macs = compute_bops_and_macs(op_res)
135+
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(op_res, "mem_w")
136+
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(op_res, "mem_o")
137+
op_res["total_bops"] = bops
138+
op_res["total_macs"] = macs
139+
op_res["total_mem_w_bits"] = mem_w_bits
140+
op_res["total_mem_w_elems"] = mem_w_elems
141+
op_res["total_mem_o_bits"] = mem_o_bits
142+
op_res["total_mem_o_elems"] = mem_o_elems
143+
per_optype_breakdown[optype] = op_res
144+
combined_results[i] = per_optype_breakdown
145+
else:
146+
per_node_breakdown = {}
147+
for node_name in res.keys():
148+
node_cost = res[node_name]
149+
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(node_cost, "mem_w")
150+
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(node_cost, "mem_o")
151+
node_cost["total_mem_w_bits"] = mem_w_bits
152+
node_cost["total_mem_w_elems"] = mem_w_elems
153+
node_cost["total_mem_o_bits"] = mem_o_bits
154+
node_cost["total_mem_o_elems"] = mem_o_elems
155+
per_node_breakdown[node_name] = node_cost
156+
combined_results[i] = per_node_breakdown
157+
return combined_results
126158

127159

128160
def main():
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) 2024 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import pytest
30+
31+
import os
32+
import urllib.request
33+
34+
from qonnx.analysis.inference_cost import aggregate_dict_keys
35+
from qonnx.core.modelwrapper import ModelWrapper
36+
from qonnx.util.cleanup import cleanup
37+
from qonnx.util.inference_cost import inference_cost as infca
38+
39+
# Download location of the ResNet-18 (opset 7) model in the onnx/models repository.
download_url = (
    "https://github.com/onnx/models/raw/main/validated/vision/"
    "classification/resnet/model/resnet18-v1-7.onnx?download="
)

# Per-model test configuration. "enc" maps short labels to the cost-dictionary
# keys that the breakdown test compares across its total / per-optype /
# per-node results.
model_details = {
    "resnet18-v1-7": dict(
        description="Resnet18 Opset version 7.",
        url=download_url,
        enc=dict(
            a="op_mac_FLOAT32_FLOAT32",
            b="total_mem_w_bits",
            c="total_mem_w_elems",
            d="total_mem_o_bits",
            e="total_mem_o_elems",
        ),
    ),
}
55+
56+
57+
def download_model(test_model, do_cleanup=False, return_modelwrapper=False, dl_dir="/tmp"):
    """Fetch a test model into a local cache directory and return it.

    The model is only downloaded when no cached copy exists yet.

    :param test_model: Key into ``model_details``; also used as the base name
        of the cached ``.onnx`` file.
    :param do_cleanup: If set, run ``cleanup`` on the downloaded file (with
        ``override_inpsize=1``) and use the resulting ``*_clean.onnx`` file.
    :param return_modelwrapper: If set, return a ``ModelWrapper`` built from
        the selected file instead of its path.
    :param dl_dir: Directory used for the cached download and the cleaned
        model. Defaults to ``/tmp``.
    :return: Path to the (optionally cleaned) model file, or a
        ``ModelWrapper`` around it when ``return_modelwrapper`` is set.
    """
    qonnx_url = model_details[test_model]["url"]
    dl_file = f"{dl_dir}/{test_model}.onnx"
    ret = dl_file
    if not os.path.isfile(dl_file):
        # Download to a side file and rename it into place afterwards, so an
        # interrupted transfer never leaves a truncated file that later runs
        # would mistake for a valid cached model.
        partial_file = f"{dl_file}.part"
        try:
            urllib.request.urlretrieve(qonnx_url, partial_file)
            os.replace(partial_file, dl_file)
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            if os.path.exists(partial_file):
                os.remove(partial_file)
    if do_cleanup:
        out_file = f"{dl_dir}/{test_model}_clean.onnx"
        cleanup(dl_file, out_file=out_file, override_inpsize=1)
        ret = out_file
    if return_modelwrapper:
        ret = ModelWrapper(ret)
    return ret
72+
73+
74+
@pytest.mark.parametrize("test_model", model_details.keys())
def test_inference_cost_breakdown(test_model):
    """Check that total, per-optype and per-node inference costs agree.

    Runs the inference cost utility with ``cost_breakdown=True`` on the
    cleaned-up test model, aggregates the per-optype and per-node results with
    ``aggregate_dict_keys``, and asserts that each metric listed under the
    model's ``enc`` entry has the same value in all three views.
    """
    expected = model_details[test_model]
    wrapped_model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
    breakdown = infca(wrapped_model, discount_sparsity=False, cost_breakdown=True)
    print(breakdown.keys())
    total = breakdown["total_cost"]
    # Fold the nested {optype: costs} and {node: costs} dicts into flat sums.
    by_optype = aggregate_dict_keys(breakdown["optype_cost"])
    by_node = aggregate_dict_keys(breakdown["node_cost"])
    metric_names = expected["enc"]
    for label in ("a", "b", "c", "d", "e"):
        metric = metric_names[label]
        assert total[metric] == by_optype[metric] == by_node[metric], "inf discrepancy"

0 commit comments

Comments
 (0)