@@ -71,7 +71,13 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
71
71
72
72
73
73
def inference_cost (
74
- model_filename_or_wrapper , * , output_json = None , output_onnx = None , preprocess = True , discount_sparsity = True
74
+ model_filename_or_wrapper ,
75
+ * ,
76
+ output_json = None ,
77
+ output_onnx = None ,
78
+ preprocess = True ,
79
+ discount_sparsity = True ,
80
+ cost_breakdown = False
75
81
):
76
82
"""Return the inference cost estimate metric for given ONNX model.
77
83
Supports the Quant op for weight/activation quantization.
@@ -83,8 +89,8 @@ def inference_cost(
83
89
:param preprocess: If set, run preprocessing steps such as shape inference,
84
90
datatype inference and constant folding. Strongly recommended.
85
91
:param discount_sparsity: If set, will discount op cost of MAC ops with a
86
- constant zero weight, and the mem cost of constant zero weights.
87
- """
92
+ constant zero weight, and the mem cost of constant zero weights."""
93
+ combined_results = {}
88
94
if isinstance (model_filename_or_wrapper , ModelWrapper ):
89
95
model = model_filename_or_wrapper
90
96
else :
@@ -104,25 +110,51 @@ def inference_cost(
104
110
model = model .transform (GiveReadableTensorNames ())
105
111
if output_onnx is not None :
106
112
model .save (output_onnx )
107
- ret = model .analysis (lambda x : infca .inference_cost (x , discount_sparsity ))
108
- bops , macs = compute_bops_and_macs (ret )
109
- mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (ret , "mem_w" )
110
- mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (ret , "mem_o" )
111
- ret ["total_bops" ] = bops
112
- ret ["total_macs" ] = macs
113
- ret ["total_mem_w_bits" ] = mem_w_bits
114
- ret ["total_mem_w_elems" ] = mem_w_elems
115
- ret ["total_mem_o_bits" ] = mem_o_bits
116
- ret ["total_mem_o_elems" ] = mem_o_elems
117
-
118
- if "unsupported" in ret :
119
- ret ["unsupported" ] = str (ret ["unsupported" ])
120
-
121
- if output_json is not None :
122
- with open (output_json , "w" ) as f :
123
- json .dump (ret , f , sort_keys = True , indent = 2 )
124
-
125
- return ret
113
+ ret = model .analysis (lambda x : infca .inference_cost (x , discount_sparsity , cost_breakdown ))
114
+ for i , res in ret .items ():
115
+ if i == "total_cost" :
116
+ bops , macs = compute_bops_and_macs (res )
117
+ mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (res , "mem_w" )
118
+ mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (res , "mem_o" )
119
+ res ["total_bops" ] = bops
120
+ res ["total_macs" ] = macs
121
+ res ["total_mem_w_bits" ] = mem_w_bits
122
+ res ["total_mem_w_elems" ] = mem_w_elems
123
+ res ["total_mem_o_bits" ] = mem_o_bits
124
+ res ["total_mem_o_elems" ] = mem_o_elems
125
+ if "unsupported" in res :
126
+ res ["unsupported" ] = str (res ["unsupported" ])
127
+ if output_json is not None :
128
+ with open (output_json , "w" ) as f :
129
+ json .dump (res , f , sort_keys = True , indent = 2 )
130
+ combined_results [i ] = res
131
+ elif i == "optype_cost" :
132
+ per_optype_breakdown = {}
133
+ for optype , op_res in res .items ():
134
+ bops , macs = compute_bops_and_macs (op_res )
135
+ mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (op_res , "mem_w" )
136
+ mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (op_res , "mem_o" )
137
+ op_res ["total_bops" ] = bops
138
+ op_res ["total_macs" ] = macs
139
+ op_res ["total_mem_w_bits" ] = mem_w_bits
140
+ op_res ["total_mem_w_elems" ] = mem_w_elems
141
+ op_res ["total_mem_o_bits" ] = mem_o_bits
142
+ op_res ["total_mem_o_elems" ] = mem_o_elems
143
+ per_optype_breakdown [optype ] = op_res
144
+ combined_results [i ] = per_optype_breakdown
145
+ else :
146
+ per_node_breakdown = {}
147
+ for node_name in res .keys ():
148
+ node_cost = res [node_name ]
149
+ mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (node_cost , "mem_w" )
150
+ mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (node_cost , "mem_o" )
151
+ node_cost ["total_mem_w_bits" ] = mem_w_bits
152
+ node_cost ["total_mem_w_elems" ] = mem_w_elems
153
+ node_cost ["total_mem_o_bits" ] = mem_o_bits
154
+ node_cost ["total_mem_o_elems" ] = mem_o_elems
155
+ per_node_breakdown [node_name ] = node_cost
156
+ combined_results [i ] = per_node_breakdown
157
+ return combined_results
126
158
127
159
128
160
def main ():
0 commit comments