Commit 6c9be27

[mlir][tensor] Fold identity reshape of 0d-tensors (#146375)
Just like 1d-tensors, reshapes of 0d-tensors (aka scalars) with identical source and result types are always no-ops, as they have only one possible layout. This PR extends the `fold` implementation to optimize these away, as is already done for 1d-tensors.
1 parent 9262ac3 commit 6c9be27

2 files changed: +14 -3 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 3 deletions
@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};
 
-  // If the source and result are both 1D tensors and have the same type, the
-  // reshape has no effect, even if the tensor is dynamically shaped.
-  if (sourceTy.getRank() == 1)
+  // If the source and result are both 0D or 1D tensors and have the same type,
+  // the reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() <= 1)
     return source;
 
   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
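
The rank check in this hunk can be illustrated outside of MLIR with a small standalone model. The sketch below is not the MLIR API: `Shape`, `kDynamic`, and `foldsByRankCheck` are hypothetical names introduced only for illustration, and the function models just the early return shown above, not the rest of `ReshapeOp::fold`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a ranked tensor type: only the shape is modeled.
// kDynamic is an illustrative sentinel for a '?' dimension, not MLIR's constant.
constexpr std::int64_t kDynamic = -1;
using Shape = std::vector<std::int64_t>;

// Models only the early return from the hunk: when the source and result
// types are identical and the rank is 0 or 1, the reshape is a no-op.
bool foldsByRankCheck(const Shape &sourceTy, const Shape &resultTy) {
  if (sourceTy != resultTy)
    return false;               // different types: this check does not apply
  return sourceTy.size() <= 1;  // rank 0 (scalar) or rank 1
}
```

Under this model, `foldsByRankCheck({}, {})` and `foldsByRankCheck({kDynamic}, {kDynamic})` both return true, while a pair of rank-2 dynamic shapes returns false. That matches the reasoning in the commit message: a `tensor<?x?xf32>` reshaped to `tensor<?x?xf32>` could turn a 2x6 tensor into a 3x4 one, so type equality alone is not enough above rank 1.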

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
 
 // -----
 
+// CHECK-LABEL: func @fold_reshape_0d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
+  %0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 // CHECK-NOT: tensor.extract_slice
 // CHECK: arith.constant dense<42> : tensor<4x4xi32>
