Skip to content

Commit a4e7e35

Browse files
committed
[Test] fix: change return style for inference cost
1 parent 0ca12ce commit a4e7e35

File tree

5 files changed

+88
-76
lines changed

5 files changed

+88
-76
lines changed

src/qonnx/util/inference_cost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def inference_cost(
133133
if "unsupported" in res:
134134
res["unsupported"] = str(res["unsupported"])
135135
combined_results[i] = res
136-
else:
136+
elif i in ["optype_cost", "node_cost"]:
137137
per_optype_or_node_breakdown = {}
138138
for optype, op_res in res.items():
139139
bops, macs = compute_bops_and_macs(op_res)

tests/analysis/test_inference_cost.py

Lines changed: 82 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -34,90 +34,102 @@
3434
model_details_infcost = {
3535
"FINN-CNV_W2A2": {
3636
"expected_sparse": {
37-
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
38-
"mem_w_INT2": 908033.0,
39-
"mem_o_SCALEDINT<32>": 57600.0,
40-
"op_mac_INT2_INT2": 35615771.0,
41-
"mem_o_INT32": 85002.0,
42-
"unsupported": "set()",
43-
"discount_sparsity": True,
44-
"total_bops": 163991084.0,
45-
"total_macs": 36961271.0,
46-
"total_mem_w_bits": 1816066.0,
47-
"total_mem_w_elems": 908033.0,
48-
"total_mem_o_bits": 4563264.0,
49-
"total_mem_o_elems": 142602.0,
37+
"total_cost": {
38+
"op_mac_SCALEDINT<8>_INT2": 1345500.0,
39+
"mem_w_INT2": 908033.0,
40+
"mem_o_SCALEDINT<32>": 57600.0,
41+
"op_mac_INT2_INT2": 35615771.0,
42+
"mem_o_INT32": 85002.0,
43+
"unsupported": "set()",
44+
"discount_sparsity": True,
45+
"total_bops": 163991084.0,
46+
"total_macs": 36961271.0,
47+
"total_mem_w_bits": 1816066.0,
48+
"total_mem_w_elems": 908033.0,
49+
"total_mem_o_bits": 4563264.0,
50+
"total_mem_o_elems": 142602.0,
51+
}
5052
},
5153
"expected_dense": {
52-
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
53-
"mem_w_INT2": 1542848.0,
54-
"mem_o_SCALEDINT<32>": 57600.0,
55-
"op_mac_INT2_INT2": 57906176.0,
56-
"mem_o_INT32": 85002.0,
57-
"unsupported": "set()",
58-
"discount_sparsity": False,
59-
"total_bops": 256507904.0,
60-
"total_macs": 59461376.0,
61-
"total_mem_w_bits": 3085696.0,
62-
"total_mem_w_elems": 1542848.0,
63-
"total_mem_o_bits": 4563264.0,
64-
"total_mem_o_elems": 142602.0,
54+
"total_cost": {
55+
"op_mac_SCALEDINT<8>_INT2": 1555200.0,
56+
"mem_w_INT2": 1542848.0,
57+
"mem_o_SCALEDINT<32>": 57600.0,
58+
"op_mac_INT2_INT2": 57906176.0,
59+
"mem_o_INT32": 85002.0,
60+
"unsupported": "set()",
61+
"discount_sparsity": False,
62+
"total_bops": 256507904.0,
63+
"total_macs": 59461376.0,
64+
"total_mem_w_bits": 3085696.0,
65+
"total_mem_w_elems": 1542848.0,
66+
"total_mem_o_bits": 4563264.0,
67+
"total_mem_o_elems": 142602.0,
68+
}
6569
},
6670
},
6771
"FINN-TFC_W2A2": {
6872
"expected_sparse": {
69-
"op_mac_INT2_INT2": 22355.0,
70-
"mem_w_INT2": 22355.0,
71-
"mem_o_INT32": 202.0,
72-
"unsupported": "set()",
73-
"discount_sparsity": True,
74-
"total_bops": 89420.0,
75-
"total_macs": 22355.0,
76-
"total_mem_w_bits": 44710.0,
77-
"total_mem_w_elems": 22355.0,
78-
"total_mem_o_bits": 6464.0,
79-
"total_mem_o_elems": 202.0,
73+
"total_cost": {
74+
"op_mac_INT2_INT2": 22355.0,
75+
"mem_w_INT2": 22355.0,
76+
"mem_o_INT32": 202.0,
77+
"unsupported": "set()",
78+
"discount_sparsity": True,
79+
"total_bops": 89420.0,
80+
"total_macs": 22355.0,
81+
"total_mem_w_bits": 44710.0,
82+
"total_mem_w_elems": 22355.0,
83+
"total_mem_o_bits": 6464.0,
84+
"total_mem_o_elems": 202.0,
85+
}
8086
},
8187
"expected_dense": {
82-
"op_mac_INT2_INT2": 59008.0,
83-
"mem_w_INT2": 59008.0,
84-
"mem_o_INT32": 202.0,
85-
"unsupported": "set()",
86-
"discount_sparsity": False,
87-
"total_bops": 236032.0,
88-
"total_macs": 59008.0,
89-
"total_mem_w_bits": 118016.0,
90-
"total_mem_w_elems": 59008.0,
91-
"total_mem_o_bits": 6464.0,
92-
"total_mem_o_elems": 202.0,
88+
"total_cost": {
89+
"op_mac_INT2_INT2": 59008.0,
90+
"mem_w_INT2": 59008.0,
91+
"mem_o_INT32": 202.0,
92+
"unsupported": "set()",
93+
"discount_sparsity": False,
94+
"total_bops": 236032.0,
95+
"total_macs": 59008.0,
96+
"total_mem_w_bits": 118016.0,
97+
"total_mem_w_elems": 59008.0,
98+
"total_mem_o_bits": 6464.0,
99+
"total_mem_o_elems": 202.0,
100+
}
93101
},
94102
},
95103
"RadioML_VGG10": {
96104
"expected_sparse": {
97-
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
98-
"mem_w_SCALEDINT<8>": 155617.0,
99-
"mem_o_SCALEDINT<32>": 130328.0,
100-
"unsupported": "set()",
101-
"discount_sparsity": True,
102-
"total_bops": 807699904.0,
103-
"total_macs": 12620311.0,
104-
"total_mem_w_bits": 1244936.0,
105-
"total_mem_w_elems": 155617.0,
106-
"total_mem_o_bits": 4170496.0,
107-
"total_mem_o_elems": 130328.0,
105+
"total_cost": {
106+
"unsupported": "set()",
107+
"discount_sparsity": True,
108+
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
109+
"mem_w_SCALEDINT<8>": 155617.0,
110+
"mem_o_SCALEDINT<32>": 130328.0,
111+
"total_bops": 807699904.0,
112+
"total_macs": 12620311.0,
113+
"total_mem_w_bits": 1244936.0,
114+
"total_mem_w_elems": 155617.0,
115+
"total_mem_o_bits": 4170496.0,
116+
"total_mem_o_elems": 130328.0,
117+
}
108118
},
109119
"expected_dense": {
110-
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
111-
"mem_w_SCALEDINT<8>": 159104.0,
112-
"mem_o_SCALEDINT<32>": 130328.0,
113-
"unsupported": "set()",
114-
"discount_sparsity": False,
115-
"total_bops": 823328768.0,
116-
"total_macs": 12864512.0,
117-
"total_mem_w_bits": 1272832.0,
118-
"total_mem_w_elems": 159104.0,
119-
"total_mem_o_bits": 4170496.0,
120-
"total_mem_o_elems": 130328.0,
120+
"total_cost": {
121+
"unsupported": "set()",
122+
"discount_sparsity": False,
123+
"op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
124+
"mem_w_SCALEDINT<8>": 159104.0,
125+
"mem_o_SCALEDINT<32>": 130328.0,
126+
"total_bops": 823328768.0,
127+
"total_macs": 12864512.0,
128+
"total_mem_w_bits": 1272832.0,
129+
"total_mem_w_elems": 159104.0,
130+
"total_mem_o_bits": 4170496.0,
131+
"total_mem_o_elems": 130328.0,
132+
}
121133
},
122134
},
123135
}

tests/analysis/test_matmul_mac_cost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ def test_matmul_mac_cost():
4040
cleaned_model = cleanup_model(model)
4141
# Two Matmul layers with shape (i_shape, w_shape, o_shape),
4242
# L1: ([4, 64, 32], [4, 32, 64], [4, 64, 64]) and L2: ([4, 64, 64], [4, 64, 32], [4, 64, 32])
43-
inf_cost_dict = infc.inference_cost(cleaned_model, discount_sparsity=False)
43+
inf_cost_dict = infc.inference_cost(cleaned_model, discount_sparsity=False)["total_cost"]
4444
mac_cost = inf_cost_dict["op_mac_FLOAT32_FLOAT32"] # Expected mac cost 4*32*64*64 + 4*64*64*32 = 1048576
4545
assert mac_cost == 1048576.0, "Error: discrepancy in mac cost."

tests/transformation/test_pruning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_pruning_mnv1():
9090
# do cleanup including folding quantized weights
9191
model = cleanup_model(model, False)
9292
inp, golden = get_golden_in_and_output("MobileNetv1-w4a4")
93-
cost0 = inference_cost(model, discount_sparsity=False)
93+
cost0 = inference_cost(model, discount_sparsity=False)["total_cost"]
9494
assert cost0["op_mac_SCALEDINT<8>_SCALEDINT<8>"] == 10645344.0
9595
assert cost0["mem_w_SCALEDINT<8>"] == 864.0
9696
assert cost0["op_mac_SCALEDINT<4>_SCALEDINT<4>"] == 556357408.0
@@ -105,7 +105,7 @@ def test_pruning_mnv1():
105105
}
106106

107107
model = model.transform(PruneChannels(prune_spec))
108-
cost1 = inference_cost(model, discount_sparsity=False)
108+
cost1 = inference_cost(model, discount_sparsity=False)["total_cost"]
109109
assert cost1["op_mac_SCALEDINT<8>_SCALEDINT<8>"] == 7318674.0
110110
assert cost1["mem_w_SCALEDINT<8>"] == 594.0
111111
assert cost1["op_mac_SCALEDINT<4>_SCALEDINT<4>"] == 546053216.0

tests/transformation/test_quantize_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ def to_verify(model, test_details):
120120
def test_quantize_graph(test_model):
121121
test_details = model_details[test_model]
122122
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
123-
original_model_inf_cost = inference_cost(model, discount_sparsity=False)
123+
original_model_inf_cost = inference_cost(model, discount_sparsity=False)["total_cost"]
124124
nodes_pos = test_details["test_input"]
125125
model = model.transform(QuantizeGraph(nodes_pos))
126126
quantnodes_added = len(model.get_nodes_by_op_type("Quant"))
127127
assert quantnodes_added == 10 # 10 positions are specified.
128128
verification = to_verify(model, nodes_pos)
129129
assert verification == "Success"
130-
inf_cost = inference_cost(model, discount_sparsity=False)
130+
inf_cost = inference_cost(model, discount_sparsity=False)["total_cost"]
131131
assert (
132132
inf_cost["total_macs"] == original_model_inf_cost["total_macs"]
133133
) # "1814073344.0" must be same as the original model.

0 commit comments

Comments (0)