diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a0dc8964f7a02..85ab48b37813f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3429,7 +3429,8 @@ LogicalResult TransposeConv2DOp::verify() { return success(); const int64_t outputChannels = outputType.getDimSize(3); - if (biasChannels != outputChannels && biasChannels != 1) + if (!ShapedType::isDynamic(outputChannels) && + biasChannels != outputChannels && biasChannels != 1) return emitOpError( "bias channels expected to be equal to output channels (") << outputChannels << ") or 1, got " << biasChannels; diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index e280c1155f526..d0f40279421f4 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1032,6 +1032,15 @@ func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x // ----- +// CHECK-LABEL: @transpose_conv2d_dynamic_out_channels +func.func @transpose_conv2d_dynamic_out_channels(%arg0: tensor<2x1x1x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { + // CHECK: -> tensor<2x3x6x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array, stride = array} : (tensor<2x1x1x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6x?xf32> + return +} + +// ----- + // CHECK-LABEL: @resize_int_horizontal func.func @resize_int_horizontal(%arg0: tensor<1x15x13x1xi8>) { %scale = tosa.const_shape { values = dense<[11, 7, 89, 6]> : tensor<4xindex> } : () -> !tosa.shape<4>