Skip to content

Commit be9a9f8

Browse files
authored
Update inference_cost.py
1 parent d120742 commit be9a9f8

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/qonnx/util/inference_cost.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from qonnx.transformation.infer_datatypes import InferDataTypes
4545
from qonnx.transformation.infer_shapes import InferShapes
4646

47-
4847
def compute_bops_and_macs(inf_cost_dict):
4948
total_bops = 0.0
5049
total_macs = 0.0
@@ -57,7 +56,6 @@ def compute_bops_and_macs(inf_cost_dict):
5756
total_macs += v
5857
return total_bops, total_macs
5958

60-
6159
def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):
6260
total_mem_bits = 0.0
6361
total_mem_elems = 0.0
@@ -98,6 +96,7 @@ def inference_cost(
9896
datatype inference and constant folding. Strongly recommended.
9997
:param discount_sparsity: If set, will discount op cost of MAC ops with a
10098
constant zero weight, and the mem cost of constant zero weights."""
99+
101100
combined_results = {}
102101
if isinstance(model_filename_or_wrapper, ModelWrapper):
103102
model = model_filename_or_wrapper
@@ -118,7 +117,8 @@ def inference_cost(
118117
model = model.transform(GiveReadableTensorNames())
119118
if output_onnx is not None:
120119
model.save(output_onnx)
121-
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown))
120+
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity,
121+
cost_breakdown))
122122
for i, res in ret.items():
123123
if i == "total_cost":
124124
bops, macs = compute_bops_and_macs(res)
@@ -148,10 +148,9 @@ def inference_cost(
148148
per_node_breakdown[node_name] = node_res
149149
combined_results[i] = per_node_breakdown
150150
return combined_results
151-
151+
152152
def main():
153153
clize.run(inference_cost)
154154

155-
156155
if __name__ == "__main__":
157156
main()

0 commit comments

Comments
 (0)