Skip to content

Commit 1f8938a

Browse files
authored
Merge pull request #142 from fastmachinelearning/feature/test_a2q_nets
Add easy fetch of accumulator-aware quantized (A2Q) CIFAR-10 models for testing
2 parents c9ad9e5 + 8694a6d commit 1f8938a

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

src/qonnx/util/test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,76 @@
3737
# utility functions to fetch models and data for
3838
# testing various qonnx transformations
3939

40+
a2q_rn18_preproc_mean = np.asarray([0.491, 0.482, 0.447], dtype=np.float32)
41+
a2q_rn18_preproc_std = np.asarray([0.247, 0.243, 0.262], dtype=np.float32)
42+
a2q_rn18_int_range = (0, 255)
43+
a2q_rn18_iscale = 1 / 255
44+
a2q_rn18_rmin = (a2q_rn18_int_range[0] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std
45+
a2q_rn18_rmax = (a2q_rn18_int_range[1] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std
46+
a2q_rn18_scale = (1 / a2q_rn18_preproc_std) * a2q_rn18_iscale
47+
a2q_rn18_bias = -a2q_rn18_preproc_mean * a2q_rn18_preproc_std
48+
a2q_rn18_common = {
49+
"input_shape": (1, 3, 32, 32),
50+
"input_range": (a2q_rn18_rmin, a2q_rn18_rmax),
51+
"int_range": a2q_rn18_int_range,
52+
"scale": a2q_rn18_scale,
53+
"bias": a2q_rn18_bias,
54+
}
55+
a2q_rn18_urlbase = "https://github.com/fastmachinelearning/qonnx_model_zoo/releases/download/a2q-20240905/"
56+
57+
a2q_model_details = {
58+
"rn18_w4a4_a2q_16b": {
59+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q 16-bit accumulators",
60+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_16b-d4bfa990.onnx",
61+
**a2q_rn18_common,
62+
},
63+
"rn18_w4a4_a2q_15b": {
64+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q 15-bit accumulators",
65+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_15b-eeca8ac2.onnx",
66+
**a2q_rn18_common,
67+
},
68+
"rn18_w4a4_a2q_14b": {
69+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q 14-bit accumulators",
70+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_14b-563cf426.onnx",
71+
**a2q_rn18_common,
72+
},
73+
"rn18_w4a4_a2q_13b": {
74+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q 13-bit accumulators",
75+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_13b-d3cae293.onnx",
76+
**a2q_rn18_common,
77+
},
78+
"rn18_w4a4_a2q_12b": {
79+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q 12-bit accumulators",
80+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_12b-fb3a0f8a.onnx",
81+
**a2q_rn18_common,
82+
},
83+
"rn18_w4a4_a2q_plus_16b": {
84+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 16-bit accumulators",
85+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_16b-09e47feb.onnx",
86+
**a2q_rn18_common,
87+
},
88+
"rn18_w4a4_a2q_plus_15b": {
89+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 15-bit accumulators",
90+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_15b-10e7bc83.onnx",
91+
**a2q_rn18_common,
92+
},
93+
"rn18_w4a4_a2q_plus_14b": {
94+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 14-bit accumulators",
95+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_14b-8db8c78c.onnx",
96+
**a2q_rn18_common,
97+
},
98+
"rn18_w4a4_a2q_plus_13b": {
99+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 13-bit accumulators",
100+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_13b-f57b05ce.onnx",
101+
**a2q_rn18_common,
102+
},
103+
"rn18_w4a4_a2q_plus_12b": {
104+
"description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 12-bit accumulators",
105+
"url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_12b-1e2aca29.onnx",
106+
**a2q_rn18_common,
107+
},
108+
}
109+
40110
test_model_details = {
41111
"FINN-CNV_W2A2": {
42112
"description": "2-bit VGG-10-like CNN on CIFAR-10",
@@ -116,6 +186,7 @@
116186
"input_shape": (1, 3, 224, 224),
117187
"input_range": (0, 1),
118188
},
189+
**a2q_model_details,
119190
}
120191

121192

@@ -149,6 +220,11 @@ def get_random_input(test_model, seed=42):
149220
rng = np.random.RandomState(seed)
150221
input_shape = test_model_details[test_model]["input_shape"]
151222
(low, high) = test_model_details[test_model]["input_range"]
223+
# some models spec per-channel ranges, be conservative for those
224+
if isinstance(low, np.ndarray):
225+
low = low.max()
226+
if isinstance(high, np.ndarray):
227+
high = high.min()
152228
size = np.prod(np.asarray(input_shape))
153229
input_tensor = rng.uniform(low=low, high=high, size=size)
154230
input_tensor = input_tensor.astype(np.float32)

tests/transformation/test_change_batchsize.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def test_change_batchsize(test_model):
4545
batch_size = 10
4646
old_ishape = test_details["input_shape"]
4747
imin, imax = test_details["input_range"]
48+
# some models spec per-channel ranges, be conservative for those
49+
if isinstance(imin, np.ndarray):
50+
imin = imin.max()
51+
if isinstance(imax, np.ndarray):
52+
imax = imax.min()
4853
model = download_model(test_model=test_model, do_cleanup=True, return_modelwrapper=True)
4954
iname = model.graph.input[0].name
5055
oname = model.graph.output[0].name

0 commit comments

Comments
 (0)