Skip to content

Commit e116fc2

Browse files
author
klassen9
committed
Add resize to change_3d_tensors_to_4d transformation
1 parent db969e6 commit e116fc2

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/qonnx/transformation/change_3d_tensors_to_4d.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _find_invalid_nodes(model):
6464
"Reshape",
6565
"MaxPool",
6666
"Upsample",
67+
"Resize",
6768
]
6869
invalid_nodes = []
6970
for n in model.graph.node:
@@ -194,6 +195,42 @@ def apply(self, model):
194195
assert list(scales.shape) == [3]
195196
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
196197
model.set_initializer(n.input[1], scales)
198+
elif node_op_type == "Resize":
199+
if len(n.input) == 2:
200+
# Resize version 10
201+
scales = model.get_initializer(n.input[1])
202+
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
203+
model.set_initializer(n.input[1], scales)
204+
elif len(n.input) == 3:
205+
# Resize version 11 and up (no size input)
206+
scales = model.get_initializer(n.input[2])
207+
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
208+
model.set_initializer(n.input[2], scales)
209+
elif len(n.input) == 4:
210+
scales_exists = (model.get_initializer(n.input[2]) is not None) and (len(model.get_initializer(n.input[2])) != 0)
211+
sizes_exists = (model.get_initializer(n.input[3]) is not None) and (len(model.get_initializer(n.input[3])) != 0)
212+
assert (scales_exists ^ sizes_exists), (
213+
"%s: Either scales or the target output size must "
214+
"be specified. Specifying both is prohibited." % n.name
215+
)
216+
if (scales_exists):
217+
# Scales parameter is a 1d list of upsampling factors along each axis
218+
scales = model.get_initializer(n.input[2])
219+
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
220+
model.set_initializer(n.input[2], scales)
221+
else:
222+
# Size parameter is a 1d list of the target size along each axis
223+
sizes = model.get_initializer(n.input[3])
224+
sizes = np.append(sizes, np.asarray(1.0, dtype=np.int64))
225+
model.set_initializer(n.input[3], sizes)
226+
if len(n.input) in (3, 4) and model.get_initializer(n.input[1]) is not None:
227+
# ROI handling
228+
roi = model.get_initializer(n.input[1])
229+
d_type = roi.dtype #float64, float32 or float16
230+
# ROI for 3d tensor: [start1, start2, start3, end1, end2, end3]
231+
roi = np.concatenate((roi[0:3], np.asarray(1.0, dtype=d_type), roi[3:6], np.asarray(1.0, dtype=d_type)), axis=None)
232+
model.set_initializer(n.input[1], roi)
233+
input_shape.append(1)
197234

198235
# Change format of each input/value_info/output tensor
199236
for k, v in all_tensors.items():

0 commit comments

Comments
 (0)