diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 214d2ba7e1b8e..5bbe6704aac48 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return bcastOp.getSource(); } - // shape_cast(constant) -> constant + // shape_cast(constant) -> constant, + // if element type of the source and result are the same if (auto splatAttr = - llvm::dyn_cast_if_present(adaptor.getSource())) - return splatAttr.reshape(getType()); + llvm::dyn_cast_if_present(adaptor.getSource())) { + if (splatAttr.getElementType() == resultType.getElementType()) + return splatAttr.reshape(getType()); + } // shape_cast(poison) -> poison if (llvm::dyn_cast_if_present(adaptor.getSource())) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8a9e27378df61..69da8a31d2c9b 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type +func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> { + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8> + // CHECK-NOT: vector.shape_cast + %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8> + // CHECK-NOT: vector.extract + %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8> + return %2 : vector<12xi8> +} + +// ----- + // CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar // CHECK: vector.broadcast // CHECK-NOT: vector.shape_cast