Skip to content

Commit 6298b26

Browse files
nanoskriptfatcat-z
andauthored
feat: add support for ExtractImagePatches (#2188)
* feat: add support for ExtractImagePatches Signed-off-by: Nanoskript <96655713+nanoskript@users.noreply.github.com> * docs: expand non-empty output constraint for ExtractImagePatches Signed-off-by: Nanoskript <96655713+nanoskript@users.noreply.github.com> --------- Signed-off-by: Nanoskript <96655713+nanoskript@users.noreply.github.com> Co-authored-by: Jay Zhang <36183870+fatcat-z@users.noreply.github.com>
1 parent b7a2953 commit 6298b26

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
matrix_diag_part = tf.compat.v1.matrix_diag_part
7474
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
7575
fake_quant_with_min_max_vars = tf.quantization.fake_quant_with_min_max_vars
76+
extract_image_patches = tf.image.extract_patches
7677
elif Version(tf.__version__) >= Version("1.13"):
7778
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7879
conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
@@ -96,6 +97,7 @@
9697
matrix_diag_part = tf.compat.v1.matrix_diag_part
9798
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
9899
fake_quant_with_min_max_vars = tf.compat.v1.quantization.fake_quant_with_min_max_vars
100+
extract_image_patches = tf.compat.v1.extract_image_patches
99101
else:
100102
conv2d_backprop_input = tf.nn.conv2d_backprop_input
101103
conv3d_transpose = tf.nn.conv3d_transpose
@@ -113,6 +115,7 @@
113115
is_inf = tf.is_inf
114116
floormod = tf.floormod
115117
matrix_diag_part = tf.matrix_diag_part
118+
extract_image_patches = tf.extract_image_patches
116119

117120

118121
def make_xval(shape):
@@ -6361,5 +6364,22 @@ def func(tensor, indices, updates):
63616364
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
63626365
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})
63636366

6367+
@check_opset_min_version(9, "EyeLike and ConstantOfShape")
6368+
def test_extract_image_patches(self):
6369+
for rates in [[1, 1], [1, 4], [4, 1], [3, 3]]:
6370+
for _, padding, x_shape, sizes, strides in get_conv_getdata():
6371+
def func(x):
6372+
return extract_image_patches(
6373+
x,
6374+
sizes=sizes,
6375+
strides=strides,
6376+
rates=[1] + rates + [1],
6377+
padding=padding,
6378+
name=_TFOUTPUT
6379+
)
6380+
6381+
x_val = make_xval(x_shape)
6382+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
6383+
63646384
if __name__ == '__main__':
63656385
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2091,3 +2091,93 @@ def version_11(cls, ctx, node, **kwargs):
20912091
ctx.replace_all_inputs(node.output[3], sum_max_neg)
20922092

20932093
ctx.remove_node(node.name)
2094+
2095+
2096+
@tf_op("ExtractImagePatches")
2097+
class ExtractImagePatches:
2098+
@classmethod
2099+
def version_9(cls, ctx, node, **kwargs):
2100+
input_shape = ctx.get_shape(node.input[0])
2101+
output_shape = node.output_shapes[0]
2102+
2103+
sizes = node.get_attr_value("ksizes")
2104+
strides = node.get_attr_value("strides")
2105+
rates = node.get_attr_value("rates")
2106+
padding = node.get_attr_str("padding")
2107+
2108+
# This implementation of ExtractImagePatches does not generalize
2109+
# to outputs that are empty. For example:
2110+
#
2111+
# tf.image.extract_patches(
2112+
# np.random.rand(1, 1, 1, 1), sizes=[1, 2, 2, 1], strides=[1, 1, 1, 1],
2113+
# rates=[1, 1, 1, 1], padding="VALID"
2114+
# )
2115+
#
2116+
# succeeds with the output of:
2117+
#
2118+
# <tf.Tensor: shape=(1, 0, 0, 4), dtype=float64, numpy=array([], shape=(1, 0, 0, 4), dtype=float64)>
2119+
#
2120+
# whereas attempting the same here results in an "Invalid input shape" error for the "Conv" node.
2121+
utils.make_sure(0 not in output_shape, "Empty ExtractImagePatches output is unsupported.")
2122+
[_, size_rows, size_cols, _] = sizes
2123+
2124+
# Transform input into [N * C, H, W, 1].
2125+
transformed_input = ctx.make_node("Reshape", inputs=[
2126+
ctx.make_node("Transpose", inputs=node.input, attr=dict(perm=[0, 3, 1, 2])).output[0],
2127+
ctx.make_const(utils.make_name("new_shape"), np.int64([
2128+
input_shape[0] * input_shape[3],
2129+
input_shape[1],
2130+
input_shape[2],
2131+
1,
2132+
])).output[0],
2133+
])
2134+
2135+
# Create identity kernel.
2136+
k = size_rows * size_cols
2137+
identity_kernel = ctx.make_node("Reshape", inputs=[
2138+
ctx.make_node("EyeLike", inputs=[
2139+
ctx.make_node("ConstantOfShape", inputs=[
2140+
ctx.make_const(utils.make_name("eye_size"), np.array([k, k], dtype=np.int64)).output[0],
2141+
]).output[0],
2142+
]).output[0],
2143+
ctx.make_const(utils.make_name("new_shape"), np.array([
2144+
size_rows,
2145+
size_cols,
2146+
1,
2147+
k,
2148+
], dtype=np.int64)).output[0],
2149+
])
2150+
2151+
# Construct placeholder convolution node and transform into [N * C, K, ?H, ?W].
2152+
convolution = ctx.make_node("Conv", inputs=[transformed_input.output[0], identity_kernel.output[0]],
2153+
shapes=[[input_shape[0] * input_shape[3], output_shape[1], output_shape[2], k]],
2154+
attr=dict(strides=strides, dilations=rates, padding=padding, data_format="NHWC"),
2155+
dtypes=node.output_dtypes)
2156+
2157+
# Transform into [N, ?H, ?W, C * K].
2158+
output_node = ctx.make_node("Reshape", inputs=[
2159+
ctx.make_node("Transpose", inputs=[
2160+
ctx.make_node("Reshape", inputs=[
2161+
convolution.output[0],
2162+
ctx.make_const(utils.make_name("new_shape"), np.array([
2163+
input_shape[0],
2164+
input_shape[3],
2165+
output_shape[1],
2166+
output_shape[2],
2167+
k,
2168+
], dtype=np.int64)).output[0],
2169+
]).output[0],
2170+
], attr=dict(perm=[0, 2, 3, 4, 1])).output[0],
2171+
ctx.make_const(utils.make_name("new_shape"), np.array(output_shape, dtype=np.int64)).output[0],
2172+
])
2173+
2174+
# Replace original node.
2175+
ctx.replace_all_inputs(node.output[0], output_node.output[0])
2176+
ctx.remove_node(node.name)
2177+
2178+
# Transform convolution node.
2179+
kernel_shape = conv_kernel_shape(ctx, convolution, 1)
2180+
strides = conv_dims_attr(convolution, "strides")
2181+
dilations = conv_dims_attr(convolution, "dilations")
2182+
add_padding(ctx, convolution, kernel_shape, strides, dilations)
2183+
conv_convert_inputs(ctx, convolution, with_kernel=True)

0 commit comments

Comments
 (0)