44
44
from qonnx .transformation .infer_datatypes import InferDataTypes
45
45
from qonnx .transformation .infer_shapes import InferShapes
46
46
47
+
47
48
def compute_bops_and_macs (inf_cost_dict ):
48
49
total_bops = 0.0
49
50
total_macs = 0.0
@@ -56,6 +57,7 @@ def compute_bops_and_macs(inf_cost_dict):
56
57
total_macs += v
57
58
return total_bops , total_macs
58
59
60
+
59
61
def compute_mem_bits_and_elems (inf_cost_dict , filter_string = "mem_w" ):
60
62
total_mem_bits = 0.0
61
63
total_mem_elems = 0.0
@@ -67,6 +69,7 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
67
69
total_mem_elems += v
68
70
return total_mem_bits , total_mem_elems
69
71
72
+
70
73
def assign_mem_bits_and_elems (res_dict ):
71
74
mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (res_dict , "mem_w" )
72
75
mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (res_dict , "mem_o" )
@@ -76,6 +79,7 @@ def assign_mem_bits_and_elems(res_dict):
76
79
res_dict ["total_mem_o_elems" ] = mem_o_elems
77
80
return res_dict
78
81
82
+
79
83
def inference_cost (
80
84
model_filename_or_wrapper ,
81
85
* ,
@@ -96,7 +100,7 @@ def inference_cost(
96
100
datatype inference and constant folding. Strongly recommended.
97
101
:param discount_sparsity: If set, will discount op cost of MAC ops with a
98
102
constant zero weight, and the mem cost of constant zero weights."""
99
-
103
+
100
104
combined_results = {}
101
105
if isinstance (model_filename_or_wrapper , ModelWrapper ):
102
106
model = model_filename_or_wrapper
@@ -117,8 +121,7 @@ def inference_cost(
117
121
model = model .transform (GiveReadableTensorNames ())
118
122
if output_onnx is not None :
119
123
model .save (output_onnx )
120
- ret = model .analysis (lambda x : infca .inference_cost (x , discount_sparsity ,
121
- cost_breakdown ))
124
+ ret = model .analysis (lambda x : infca .inference_cost (x , discount_sparsity , cost_breakdown ))
122
125
for i , res in ret .items ():
123
126
if i == "total_cost" :
124
127
bops , macs = compute_bops_and_macs (res )
@@ -148,9 +151,11 @@ def inference_cost(
148
151
per_node_breakdown [node_name ] = node_res
149
152
combined_results [i ] = per_node_breakdown
150
153
return combined_results
151
-
154
+
155
+
152
156
def main ():
153
157
clize .run (inference_cost )
154
158
159
+
155
160
if __name__ == "__main__" :
156
161
main ()
0 commit comments