Skip to content

Commit e60c192

Browse files
committed
[Tosa] : Equalize all operands for select.
1 parent 80a3dfd commit e60c192

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5069,7 +5069,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
50695069
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
50705070

50715071
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, self).failed() ||
5072-
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed())
5072+
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed() ||
5073+
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed())
50735074
return rewriter.notifyMatchFailure(
50745075
op, "Failed to equalize ranks among operands and result");
50755076

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,28 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
14211421
return %0 : !torch.vtensor<[1,12,5,5],f32>
14221422
}
14231423

1424+
// -----
1425+
// CHECK-LABEL: func.func @torch.aten.where.self_differing_rank_inputs(
1426+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4],i1>,
1427+
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
1428+
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,3,1,1,5,4],f32>) -> !torch.vtensor<[1,3,1,1,5,4],f32> {
1429+
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1,3,1,1,5,4],f32> -> tensor<1x3x1x1x5x4xf32>
1430+
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],f32> -> tensor<f32>
1431+
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4],i1> -> tensor<5x4xi1>
1432+
// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
1433+
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
1434+
// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 1, 5, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
1435+
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<5x4xi1>, !tosa.shape<6>) -> tensor<1x1x1x1x5x4xi1>
1436+
// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
1437+
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<1x1xf32>, !tosa.shape<6>) -> tensor<1x1x1x1x1x1xf32>
1438+
// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : (tensor<1x1x1x1x5x4xi1>, tensor<1x1x1x1x1x1xf32>, tensor<1x3x1x1x5x4xf32>) -> tensor<1x3x1x1x5x4xf32>
1439+
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x3x1x1x5x4xf32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
1440+
// CHECK: return %[[VAL_13]]
1441+
func.func @torch.aten.where.self_differing_rank_inputs(%40: !torch.vtensor<[5,4],i1>, %41: !torch.vtensor<[],f32>, %38 : !torch.vtensor<[1,3,1,1,5,4],f32>) -> (!torch.vtensor<[1,3,1,1,5,4],f32>) {
1442+
%42 = torch.aten.where.self %40, %41, %38 : !torch.vtensor<[5,4],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[1,3,1,1,5,4],f32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
1443+
return %42: !torch.vtensor<[1,3,1,1,5,4],f32>
1444+
}
1445+
14241446
// -----
14251447
// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
14261448
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {

0 commit comments

Comments
 (0)