diff --git a/tf2onnx/tflite_handlers/tfl_math.py b/tf2onnx/tflite_handlers/tfl_math.py index add2f7de3..35b377c56 100644 --- a/tf2onnx/tflite_handlers/tfl_math.py +++ b/tf2onnx/tflite_handlers/tfl_math.py @@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs): separate_fused_activation_function(ctx, node) utils.make_sure(node.attr['weights_format'].s == b'DEFAULT', "Only default weights format supported for fully connected op") - utils.make_sure(node.attr['keep_num_dims'].i == 0, - "Only keep_num_dims=False supported for fully connected op") if node.attr['asymmetric_quantize_inputs'].i == 1: dynamic_quantize_inputs(ctx, node) - if ctx.get_rank(node.input[0]) != 2: + if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2: # When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2") weights_shape = ctx.get_shape(node.input[1])[1]