Skip to content

Commit 1c22382

Browse files
authored
[mlir][tosa] Fix transpose_conv2d verifier when output channels are dynamic (#147062)
This commit fixes a transpose_conv2d verifier check which compares the output channels size to the bias size. The check didn't make sure output channels were static before performing the comparison. This lead to failures such as: ``` 'tosa.transpose_conv2d' op bias channels expected to be equal to output channels (-9223372036854775808) or 1, got 5 ``` when the output channels size was dynamic.
1 parent 29e14c3 commit 1c22382

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3417,7 +3417,8 @@ LogicalResult TransposeConv2DOp::verify() {
34173417
return success();
34183418

34193419
const int64_t outputChannels = outputType.getDimSize(3);
3420-
if (biasChannels != outputChannels && biasChannels != 1)
3420+
if (!ShapedType::isDynamic(outputChannels) &&
3421+
biasChannels != outputChannels && biasChannels != 1)
34213422
return emitOpError(
34223423
"bias channels expected to be equal to output channels (")
34233424
<< outputChannels << ") or 1, got " << biasChannels;

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,15 @@ func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x
10321032

10331033
// -----
10341034

1035+
// CHECK-LABEL: @transpose_conv2d_dynamic_out_channels
1036+
func.func @transpose_conv2d_dynamic_out_channels(%arg0: tensor<2x1x1x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
1037+
// CHECK: -> tensor<2x3x6x5xf32>
1038+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x1x1x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6x?xf32>
1039+
return
1040+
}
1041+
1042+
// -----
1043+
10351044
// CHECK-LABEL: @resize_int_horizontal
10361045
func.func @resize_int_horizontal(%arg0: tensor<1x15x13x1xi8>) {
10371046
%scale = tosa.const_shape { values = dense<[11, 7, 89, 6]> : tensor<4xindex> } : () -> !tosa.shape<4>

0 commit comments

Comments
 (0)