Skip to content

Commit 3aa08e8

Browse files
committed
feat: add support for ExtractImagePatches
Signed-off-by: Nanoskript <96655713+nanoskript@users.noreply.github.com>
1 parent 25c977c commit 3aa08e8

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

tests/test_backend.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
matrix_diag_part = tf.compat.v1.matrix_diag_part
7373
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
7474
fake_quant_with_min_max_vars = tf.quantization.fake_quant_with_min_max_vars
75+
extract_image_patches = tf.image.extract_patches
7576
elif Version(tf.__version__) >= Version("1.13"):
7677
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7778
conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
@@ -94,6 +95,7 @@
9495
matrix_diag_part = tf.compat.v1.matrix_diag_part
9596
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
9697
fake_quant_with_min_max_vars = tf.compat.v1.quantization.fake_quant_with_min_max_vars
98+
extract_image_patches = tf.compat.v1.extract_image_patches
9799
else:
98100
conv2d_backprop_input = tf.nn.conv2d_backprop_input
99101
conv3d_transpose = tf.nn.conv3d_transpose
@@ -111,6 +113,7 @@
111113
is_inf = tf.is_inf
112114
floormod = tf.floormod
113115
matrix_diag_part = tf.matrix_diag_part
116+
extract_image_patches = tf.extract_image_patches
114117

115118

116119
def make_xval(shape):
@@ -6283,5 +6286,21 @@ def func(tensor, indices, updates):
62836286
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
62846287
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})
62856288

6289+
def test_extract_image_patches(self):
6290+
for rates in [[1, 1], [1, 4], [4, 1], [3, 3]]:
6291+
for _, padding, x_shape, sizes, strides in get_conv_getdata():
6292+
def func(x):
6293+
return extract_image_patches(
6294+
x,
6295+
sizes=sizes,
6296+
strides=strides,
6297+
rates=[1] + rates + [1],
6298+
padding=padding,
6299+
name=_TFOUTPUT
6300+
)
6301+
6302+
x_val = make_xval(x_shape)
6303+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
6304+
62866305
if __name__ == '__main__':
62876306
unittest_main()

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2
2525
from tf2onnx.rewriter.gru_tf2_rewriter import rewrite_gru_tf2
2626
from tf2onnx.rewriter.fused_op_rewriter import rewrite_fused_ops
27+
from tf2onnx.rewriter.extract_image_patches_rewriter import rewrite_extract_image_patches
2728

2829

2930
__all__ = [
@@ -53,4 +54,5 @@
5354
"rewriter_lstm_tf2",
5455
"rewrite_gru_tf2",
5556
"rewrite_fused_ops",
57+
"rewrite_extract_image_patches",
5658
]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.rewriter.extract_image_patches_rewriter - Rewrites ExtractImagePatches into supported operations.
6+
"""
7+
8+
import numpy as np
9+
from tf2onnx import utils
10+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
11+
12+
13+
# pylint: disable=missing-docstring
14+
15+
def rewrite_extract_image_patches(g, ops):
16+
pattern = OpTypePattern("ExtractImagePatches", name="extract_image_patches")
17+
matcher = GraphMatcher(pattern)
18+
match_results = list(matcher.match_ops(ops))
19+
for match_result in match_results:
20+
operation = match_result.get_op("extract_image_patches")
21+
input_shape = g.get_shape(operation.input[0])
22+
output_shape = operation.output_shapes[0]
23+
24+
sizes = operation.get_attr_value("ksizes")
25+
strides = operation.get_attr_value("strides")
26+
rates = operation.get_attr_value("rates")
27+
padding = operation.get_attr_str("padding")
28+
29+
# Our constraints.
30+
utils.make_sure(0 not in output_shape, "Empty ExtractImagePatches output is unsupported.")
31+
[_, size_rows, size_cols, _] = sizes
32+
33+
# Transform input into [N * C, H, W, 1].
34+
transformed_input = g.make_node("Reshape", inputs=[
35+
g.make_node("Transpose", inputs=operation.input, attr=dict(perm=[0, 3, 1, 2])).output[0],
36+
g.make_const(utils.make_name("new_shape"), np.int64([
37+
input_shape[0] * input_shape[3],
38+
input_shape[1],
39+
input_shape[2],
40+
1,
41+
])).output[0],
42+
])
43+
44+
# Create identity kernel.
45+
k = size_rows * size_cols
46+
identity_kernel = g.make_node("Reshape", inputs=[
47+
g.make_node("EyeLike", inputs=[
48+
g.make_node("ConstantOfShape", inputs=[
49+
g.make_const(utils.make_name("eye_size"), np.array([k, k], dtype=np.int64)).output[0],
50+
]).output[0],
51+
]).output[0],
52+
g.make_const(utils.make_name("new_shape"), np.array([
53+
size_rows,
54+
size_cols,
55+
1,
56+
k,
57+
], dtype=np.int64)).output[0],
58+
])
59+
60+
# Convolve into [N * C, ?H, ?W, K].
61+
convolution = g.make_node("Conv2D", inputs=[transformed_input.output[0], identity_kernel.output[0]],
62+
attr=dict(strides=strides, dilations=rates, padding=padding, data_format="NHWC"),
63+
shapes=[[input_shape[0] * input_shape[3], output_shape[1], output_shape[2], k]],
64+
dtypes=operation.output_dtypes, skip_conversion=False)
65+
66+
# Transform into [N, ?H, ?W, C * K].
67+
output_node = g.make_node("Reshape", inputs=[
68+
g.make_node("Transpose", inputs=[
69+
g.make_node("Reshape", inputs=[
70+
convolution.output[0],
71+
g.make_const(utils.make_name("new_shape"), np.array([
72+
input_shape[0],
73+
input_shape[3],
74+
output_shape[1],
75+
output_shape[2],
76+
k,
77+
], dtype=np.int64)).output[0],
78+
]).output[0],
79+
], attr=dict(perm=[0, 2, 3, 4, 1])).output[0],
80+
g.make_const(utils.make_name("new_shape"), np.array(output_shape, dtype=np.int64)).output[0],
81+
])
82+
83+
# Replace node.
84+
g.replace_all_inputs(operation.output[0], output_node.output[0])
85+
g.remove_node(operation.name)
86+
return g.get_nodes()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ def compat_handler(ctx, node, **kwargs):
598598
rewriter_lstm_tf2,
599599
rewrite_gru_tf2,
600600
rewrite_single_direction_lstm,
601+
rewrite_extract_image_patches,
601602
# bi-directional
602603
rewrite_bi_direction_lstm,
603604
rewrite_single_direction_gru,

0 commit comments

Comments
 (0)