@@ -69,6 +69,14 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
69
69
total_mem_elems += v
70
70
return total_mem_bits , total_mem_elems
71
71
72
+ def assign_mem_bits_and_elems (res_dict ):
73
+ mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (res_dict , "mem_w" )
74
+ mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (res_dict , "mem_o" )
75
+ res_dict ["total_mem_w_bits" ] = mem_w_bits
76
+ res_dict ["total_mem_w_elems" ] = mem_w_elems
77
+ res_dict ["total_mem_o_bits" ] = mem_o_bits
78
+ res_dict ["total_mem_o_elems" ] = mem_o_elems
79
+ return res_dict
72
80
73
81
def inference_cost (
74
82
model_filename_or_wrapper ,
@@ -114,14 +122,9 @@ def inference_cost(
114
122
for i , res in ret .items ():
115
123
if i == "total_cost" :
116
124
bops , macs = compute_bops_and_macs (res )
117
- mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (res , "mem_w" )
118
- mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (res , "mem_o" )
125
+ res = assign_mem_bits_and_elems (res )
119
126
res ["total_bops" ] = bops
120
127
res ["total_macs" ] = macs
121
- res ["total_mem_w_bits" ] = mem_w_bits
122
- res ["total_mem_w_elems" ] = mem_w_elems
123
- res ["total_mem_o_bits" ] = mem_o_bits
124
- res ["total_mem_o_elems" ] = mem_o_elems
125
128
if "unsupported" in res :
126
129
res ["unsupported" ] = str (res ["unsupported" ])
127
130
if output_json is not None :
@@ -132,31 +135,20 @@ def inference_cost(
132
135
per_optype_breakdown = {}
133
136
for optype , op_res in res .items ():
134
137
bops , macs = compute_bops_and_macs (op_res )
135
- mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (op_res , "mem_w" )
136
- mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (op_res , "mem_o" )
138
+ op_res = assign_mem_bits_and_elems (op_res )
137
139
op_res ["total_bops" ] = bops
138
140
op_res ["total_macs" ] = macs
139
- op_res ["total_mem_w_bits" ] = mem_w_bits
140
- op_res ["total_mem_w_elems" ] = mem_w_elems
141
- op_res ["total_mem_o_bits" ] = mem_o_bits
142
- op_res ["total_mem_o_elems" ] = mem_o_elems
143
141
per_optype_breakdown [optype ] = op_res
144
142
combined_results [i ] = per_optype_breakdown
145
143
else :
146
144
per_node_breakdown = {}
147
145
for node_name in res .keys ():
148
- node_cost = res [node_name ]
149
- mem_w_bits , mem_w_elems = compute_mem_bits_and_elems (node_cost , "mem_w" )
150
- mem_o_bits , mem_o_elems = compute_mem_bits_and_elems (node_cost , "mem_o" )
151
- node_cost ["total_mem_w_bits" ] = mem_w_bits
152
- node_cost ["total_mem_w_elems" ] = mem_w_elems
153
- node_cost ["total_mem_o_bits" ] = mem_o_bits
154
- node_cost ["total_mem_o_elems" ] = mem_o_elems
155
- per_node_breakdown [node_name ] = node_cost
146
+ node_res = res [node_name ]
147
+ node_res = assign_mem_bits_and_elems (node_res )
148
+ per_node_breakdown [node_name ] = node_res
156
149
combined_results [i ] = per_node_breakdown
157
150
return combined_results
158
151
159
-
160
152
def main ():
161
153
clize .run (inference_cost )
162
154
0 commit comments