Skip to content

Commit 79cedf9

Browse files
authored
Change the way to validate keep_num_dims attribute for new tf. (#2367)
* Change the way to validate keep_num_dims attribute for new tf. Signed-off-by: Jay Zhang <jiz@microsoft.com> * Fix a lint issue. Signed-off-by: Jay Zhang <jiz@microsoft.com> --------- Signed-off-by: Jay Zhang <jiz@microsoft.com>
1 parent f85e88e commit 79cedf9

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs):
201201
separate_fused_activation_function(ctx, node)
202202
utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
203203
"Only default weights format supported for fully connected op")
204-
utils.make_sure(node.attr['keep_num_dims'].i == 0,
205-
"Only keep_num_dims=False supported for fully connected op")
206204
if node.attr['asymmetric_quantize_inputs'].i == 1:
207205
dynamic_quantize_inputs(ctx, node)
208206

209-
if ctx.get_rank(node.input[0]) != 2:
207+
if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2:
210208
# When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed
211209
utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2")
212210
weights_shape = ctx.get_shape(node.input[1])[1]

0 commit comments

Comments
 (0)