Skip to content

Commit 3e132fe

Browse files
committed
[GraphQnt] some cleanup and renaming
1 parent 2feab84 commit 3e132fe

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

src/qonnx/transformation/introduce_quantnode.py renamed to src/qonnx/transformation/quantize_graph.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -180,32 +180,30 @@ def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, n
180180
return model
181181

182182

183-
class IntroduceQuantnode(Transformation):
184-
"""This transformation can be used to introduce a Quant node for a specific type of node in the graph.
185-
Users would be able to specify the location of the quant node by providing the input and output indexs
186-
as the parameters.
183+
class QuantizeGraph(Transformation):
184+
"""This transformation can be used to introduce a Quant node for particular nodes in the graph,
185+
determined based on either op_type or node name.
186+
For the particular nodes identified, users can specify the location of the Quant nodes by providing
187+
the input and output indices where Quant nodes are to be inserted.
188+
Assumes the input model is cleaned-up with all intermediate shapes specified and nodes given
189+
unique names already.
187190
188-
1) Expectations:
189-
a) Onnx model in the modelwraper format.
190-
b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model()
191-
c) Batchsize to be set.
191+
2) Steps to transform are
192+
Step1: Finding the input for the quant node.
193+
Step2: Finding the consumer of the quant node output.
194+
Step3: Finding the shape for the output tensor of quant node.
195+
Note: The output tensor of the quant node must have the same shape as the
196+
consumer of the input to the quant node.
192197
193-
2) Steps to transform are
194-
Step1: Finding the input for the quant node.
195-
Step2: Finding the consumer of the quant node output.
196-
Step3: Finding the shape for the output tensor of quant node.
197-
Note: The output tensor of the quant node must have the same shape as the
198-
consumer of the input to the quant node.
198+
3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type".
199199
200-
3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type".
200+
4) Assert:
201+
a) The input is a dictionary representing the node names as keys and a list of quant positions
202+
as values.
203+
b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation.
201204
202-
4) Assert:
203-
a) The input is a dictionary representing the node names as keys and a list of quant positions
204-
as values.
205-
b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation.
206-
207-
5) Return:
208-
Returns a cleaned version of the model.
205+
5) Return:
206+
Returns a cleaned version of the model.
209207
210208
"""
211209

tests/transformation/test_introduce_quantnode.py renamed to tests/transformation/test_quantize_graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,21 @@
3333
import urllib.request
3434

3535
from qonnx.core.modelwrapper import ModelWrapper
36-
from qonnx.transformation.introduce_quantnode import IntroduceQuantnode, graph_util
36+
from qonnx.transformation.quantize_graph import QuantizeGraph, graph_util
3737
from qonnx.util.cleanup import cleanup
3838
from qonnx.util.inference_cost import inference_cost
3939

4040
random.seed(42)
4141

4242
graph_util = graph_util()
4343

44-
a = "https://github.com/onnx/models/raw/main/validated/vision/"
45-
b = "classification/resnet/model/resnet18-v1-7.onnx?download="
44+
download_url = "https://github.com/onnx/models/raw/main/validated/vision/"
45+
download_url += "classification/resnet/model/resnet18-v1-7.onnx?download="
4646

4747
model_details = {
4848
"resnet18-v1-7": {
4949
"description": "Resnet18 Opset version 7.",
50-
"url": (a + b),
50+
"url": download_url,
5151
"test_input": {
5252
"name": {
5353
"Conv_0": [
@@ -124,7 +124,7 @@ def test_introduce_quantnode(test_model):
124124
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
125125
original_model_inf_cost = inference_cost(model, discount_sparsity=False)
126126
nodes_pos = test_details["test_input"]
127-
model = model.transform(IntroduceQuantnode(nodes_pos))
127+
model = model.transform(QuantizeGraph(nodes_pos))
128128
quantnodes_added = len(model.get_nodes_by_op_type("Quant"))
129129
assert quantnodes_added == 10 # 10 positions are specified.
130130
verification = to_verify(model, nodes_pos)

0 commit comments

Comments
 (0)