From cb008443422319bd55ad6e20e1e025bea687759b Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 24 Oct 2024 11:42:19 +0000 Subject: [PATCH 1/2] [MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern This commit fixes the bug with the chained tensor.cast canonicalization pattern. When the sourceType and itermediateType both contains a dim which is static and not equal then the joinShapes utility returns a null value. And, this null value during the next call to the joinShapes utility results in a crash. Although, this instance of tensor.cast is invalid since the operand shape and result shape are incompatible but in any case the code should not crash, and this commit particularly fixes this kind of case. Signed-Off-By: Vivek Khandelwal --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 603e86ca3d766..13af1497d3790 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern { // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. auto firstJoin = - joinShapes(joinShapes(sourceType, intermediateType), resultType); + joinShapes(sourceType, intermediateType); // The join might not exist if the cast sequence would fail at runtime. if (!firstJoin) return failure(); + auto secondJoin = joinShapes(firstJoin, resultType); + + // The join might not exist if the cast sequence would fail at runtime. + if (!secondJoin) + return failure(); + // The newJoin always exists if the above join exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) + if (secondJoin != newJoin) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType, From e882bcbb5bdd2deae771e839ce9faa1906f1ff3a Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 24 Oct 2024 11:53:25 +0000 Subject: [PATCH 2/2] Fix code formatting --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 13af1497d3790..d1b73ff2dbd0c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -433,8 +433,7 @@ struct ChainedTensorCast : public OpRewritePattern { // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. - auto firstJoin = - joinShapes(sourceType, intermediateType); + auto firstJoin = joinShapes(sourceType, intermediateType); // The join might not exist if the cast sequence would fail at runtime. if (!firstJoin)