@@ -2091,3 +2091,81 @@ def version_11(cls, ctx, node, **kwargs):
2091
2091
ctx .replace_all_inputs (node .output [3 ], sum_max_neg )
2092
2092
2093
2093
ctx .remove_node (node .name )
2094
+
2095
+
2096
+ @tf_op ("ExtractImagePatches" )
2097
+ class ExtractImagePatches :
2098
+ @classmethod
2099
+ def version_9 (cls , ctx , node , ** kwargs ):
2100
+ input_shape = ctx .get_shape (node .input [0 ])
2101
+ output_shape = node .output_shapes [0 ]
2102
+
2103
+ sizes = node .get_attr_value ("ksizes" )
2104
+ strides = node .get_attr_value ("strides" )
2105
+ rates = node .get_attr_value ("rates" )
2106
+ padding = node .get_attr_str ("padding" )
2107
+
2108
+ # Our constraints.
2109
+ utils .make_sure (0 not in output_shape , "Empty ExtractImagePatches output is unsupported." )
2110
+ [_ , size_rows , size_cols , _ ] = sizes
2111
+
2112
+ # Transform input into [N * C, H, W, 1].
2113
+ transformed_input = ctx .make_node ("Reshape" , inputs = [
2114
+ ctx .make_node ("Transpose" , inputs = node .input , attr = dict (perm = [0 , 3 , 1 , 2 ])).output [0 ],
2115
+ ctx .make_const (utils .make_name ("new_shape" ), np .int64 ([
2116
+ input_shape [0 ] * input_shape [3 ],
2117
+ input_shape [1 ],
2118
+ input_shape [2 ],
2119
+ 1 ,
2120
+ ])).output [0 ],
2121
+ ])
2122
+
2123
+ # Create identity kernel.
2124
+ k = size_rows * size_cols
2125
+ identity_kernel = ctx .make_node ("Reshape" , inputs = [
2126
+ ctx .make_node ("EyeLike" , inputs = [
2127
+ ctx .make_node ("ConstantOfShape" , inputs = [
2128
+ ctx .make_const (utils .make_name ("eye_size" ), np .array ([k , k ], dtype = np .int64 )).output [0 ],
2129
+ ]).output [0 ],
2130
+ ]).output [0 ],
2131
+ ctx .make_const (utils .make_name ("new_shape" ), np .array ([
2132
+ size_rows ,
2133
+ size_cols ,
2134
+ 1 ,
2135
+ k ,
2136
+ ], dtype = np .int64 )).output [0 ],
2137
+ ])
2138
+
2139
+ # Construct placeholder convolution node and transform into [N * C, K, ?H, ?W].
2140
+ convolution = ctx .make_node ("Conv" , inputs = [transformed_input .output [0 ], identity_kernel .output [0 ]],
2141
+ shapes = [[input_shape [0 ] * input_shape [3 ], output_shape [1 ], output_shape [2 ], k ]],
2142
+ attr = dict (strides = strides , dilations = rates , padding = padding , data_format = "NHWC" ),
2143
+ dtypes = node .output_dtypes )
2144
+
2145
+ # Transform into [N, ?H, ?W, C * K].
2146
+ output_node = ctx .make_node ("Reshape" , inputs = [
2147
+ ctx .make_node ("Transpose" , inputs = [
2148
+ ctx .make_node ("Reshape" , inputs = [
2149
+ convolution .output [0 ],
2150
+ ctx .make_const (utils .make_name ("new_shape" ), np .array ([
2151
+ input_shape [0 ],
2152
+ input_shape [3 ],
2153
+ output_shape [1 ],
2154
+ output_shape [2 ],
2155
+ k ,
2156
+ ], dtype = np .int64 )).output [0 ],
2157
+ ]).output [0 ],
2158
+ ], attr = dict (perm = [0 , 2 , 3 , 4 , 1 ])).output [0 ],
2159
+ ctx .make_const (utils .make_name ("new_shape" ), np .array (output_shape , dtype = np .int64 )).output [0 ],
2160
+ ])
2161
+
2162
+ # Replace original node.
2163
+ ctx .replace_all_inputs (node .output [0 ], output_node .output [0 ])
2164
+ ctx .remove_node (node .name )
2165
+
2166
+ # Transform convolution node.
2167
+ kernel_shape = conv_kernel_shape (ctx , convolution , 1 )
2168
+ strides = conv_dims_attr (convolution , "strides" )
2169
+ dilations = conv_dims_attr (convolution , "dilations" )
2170
+ add_padding (ctx , convolution , kernel_shape , strides , dilations )
2171
+ conv_convert_inputs (ctx , convolution , with_kernel = True )
0 commit comments