@@ -2091,3 +2091,93 @@ def version_11(cls, ctx, node, **kwargs):
2091
2091
ctx .replace_all_inputs (node .output [3 ], sum_max_neg )
2092
2092
2093
2093
ctx .remove_node (node .name )
2094
+
2095
+
2096
@tf_op("ExtractImagePatches")
class ExtractImagePatches:
    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Rewrite an ExtractImagePatches node as Transpose/Reshape ops around a Conv.

        The input is transposed and reshaped so that every channel becomes its
        own single-channel image, convolved with a reshaped identity matrix,
        and the result is reshaped back to the output shape recorded on the
        original node. The original node is then removed from the graph.

        Args:
            ctx: graph being converted; used to create nodes/constants and to
                rewire consumers of the original output.
            node: the ExtractImagePatches node. Its "ksizes", "strides",
                "rates" and "padding" attributes are read.
            **kwargs: accepted for handler-signature compatibility; unused.

        Raises:
            Whatever ``utils.make_sure`` raises when the recorded output shape
            contains a zero dimension.
        """
        in_shape = ctx.get_shape(node.input[0])
        out_shape = node.output_shapes[0]

        ksizes = node.get_attr_value("ksizes")
        window_strides = node.get_attr_value("strides")
        window_rates = node.get_attr_value("rates")
        pad_mode = node.get_attr_str("padding")

        # Empty outputs are not supported by this lowering. TensorFlow accepts
        #
        #   tf.image.extract_patches(
        #       np.random.rand(1, 1, 1, 1), sizes=[1, 2, 2, 1], strides=[1, 1, 1, 1],
        #       rates=[1, 1, 1, 1], padding="VALID"
        #   )
        #
        # and returns a tensor of shape (1, 0, 0, 4), whereas the same input
        # here ends in an "Invalid input shape" error for the "Conv" node.
        utils.make_sure(0 not in out_shape, "Empty ExtractImagePatches output is unsupported.")
        _, patch_rows, patch_cols, _ = ksizes

        # Step 1: move channels next to the batch axis, then fold them into it,
        # giving [N * C, H, W, 1].
        # NOTE(review): this multiplies the static dims of in_shape directly, so
        # it presumably expects a fully known input shape -- confirm with callers.
        channels_first = ctx.make_node("Transpose", inputs=node.input, attr={"perm": [0, 3, 1, 2]})
        planes_dims = ctx.make_const(
            utils.make_name("new_shape"),
            np.array([in_shape[0] * in_shape[3], in_shape[1], in_shape[2], 1], dtype=np.int64),
        )
        planes = ctx.make_node("Reshape", inputs=[channels_first.output[0], planes_dims.output[0]])

        # Step 2: build a k x k identity matrix in-graph and reshape it into a
        # [rows, cols, 1, k] convolution kernel.
        # NOTE(review): ConstantOfShape gets no `value` attribute and EyeLike no
        # `dtype`, so the kernel element type is left to the ONNX defaults --
        # confirm it matches the Conv input type for non-float32 inputs.
        patch_len = patch_rows * patch_cols
        eye_dims = ctx.make_const(
            utils.make_name("eye_size"),
            np.array([patch_len, patch_len], dtype=np.int64),
        )
        eye_seed = ctx.make_node("ConstantOfShape", inputs=[eye_dims.output[0]])
        eye = ctx.make_node("EyeLike", inputs=[eye_seed.output[0]])
        kernel_dims = ctx.make_const(
            utils.make_name("new_shape"),
            np.array([patch_rows, patch_cols, 1, patch_len], dtype=np.int64),
        )
        kernel = ctx.make_node("Reshape", inputs=[eye.output[0], kernel_dims.output[0]])

        # Step 3: placeholder Conv carrying TF-style attributes ("padding",
        # "data_format"); it is converted by the conv helpers at the end.
        conv = ctx.make_node(
            "Conv",
            inputs=[planes.output[0], kernel.output[0]],
            shapes=[[in_shape[0] * in_shape[3], out_shape[1], out_shape[2], patch_len]],
            attr={
                "strides": window_strides,
                "dilations": window_rates,
                "padding": pad_mode,
                "data_format": "NHWC",
            },
            dtypes=node.output_dtypes,
        )

        # Step 4: split the folded batch back into [N, C, ?H, ?W, K], move the
        # channel axis last, and collapse to the recorded output shape.
        unfolded_dims = ctx.make_const(
            utils.make_name("new_shape"),
            np.array([in_shape[0], in_shape[3], out_shape[1], out_shape[2], patch_len], dtype=np.int64),
        )
        unfolded = ctx.make_node("Reshape", inputs=[conv.output[0], unfolded_dims.output[0]])
        channels_last = ctx.make_node("Transpose", inputs=[unfolded.output[0]], attr={"perm": [0, 2, 3, 4, 1]})
        final_dims = ctx.make_const(utils.make_name("new_shape"), np.array(out_shape, dtype=np.int64))
        result = ctx.make_node("Reshape", inputs=[channels_last.output[0], final_dims.output[0]])

        # Step 5: point every consumer of the original output at the new
        # subgraph and drop the original node.
        ctx.replace_all_inputs(node.output[0], result.output[0])
        ctx.remove_node(node.name)

        # Step 6: run the shared conv helpers on the placeholder Conv. This is
        # done after the rewiring above, as in the original ordering.
        kernel_shape = conv_kernel_shape(ctx, conv, 1)
        conv_strides = conv_dims_attr(conv, "strides")
        conv_dilations = conv_dims_attr(conv, "dilations")
        add_padding(ctx, conv, kernel_shape, conv_strides, conv_dilations)
        conv_convert_inputs(ctx, conv, with_kernel=True)
0 commit comments