[mlir][vector] shape_cast(constant) -> constant fold for non-splats #145539

Status: Open. Wants to merge 1 commit into main.
Conversation

@newling (Contributor) commented Jun 24, 2025

The folder shape_cast(splat constant) -> splat constant was first introduced here (Nov 2020). That commit contains a comment to "Only handle splat for now", so I assume the intention was to later support a general shape_cast(constant) -> constant folder. That is what this PR does.

Potential downside: It is possible with this folder to end up with 2 large constants instead of 1 large constant and 1 shape_cast:

  func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
    %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
    %0 = vector.shape_cast %cst : vector<4xi32> to vector<2x2xi32>
    return %cst, %0 : vector<4xi32>, vector<2x2xi32>
  }

gets folded with this new folder to

  func.func @foo() -> (vector<4xi32>, vector<2x2xi32>) {
    %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> // 'large' constant 1
    %cst_0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32> // 'large' constant 2
    return %cst, %cst_0 : vector<4xi32>, vector<2x2xi32>
  }

Notes on the above case:

  1. This only affects the textual IR: the actual values share the same context storage (I've verified this by checking pointer values in the DenseIntOrFPElementsAttrStorage constructor), so there is no compile-time memory overhead from this folding. I think the constant is shared at the LLVM IR level, too.
  2. This only happens when the pre-folded constant cannot be dead-code eliminated (i.e. when it has 2+ uses), which I don't think is common.
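Conceptually, the fold just reinterprets the constant's flat element list under a new shape, row-major; no element data changes. A minimal pure-Python sketch of that reshape (hypothetical helper name, not the MLIR `DenseElementsAttr::reshape` API itself):

```python
def reshape_row_major(flat, shape):
    """Reinterpret a flat element list with a new shape, row-major.

    Mirrors what DenseElementsAttr::reshape does conceptually: the
    element data is untouched, only the shape metadata changes.
    """
    total = 1
    for dim in shape:
        total *= dim
    assert total == len(flat), "shape_cast requires equal element counts"
    if len(shape) == 1:
        return list(flat)
    inner = total // shape[0]  # elements per outermost slice
    return [reshape_row_major(flat[i * inner:(i + 1) * inner], shape[1:])
            for i in range(shape[0])]

# vector<4xi32> -> vector<2x2xi32>, as in the example above.
print(reshape_row_major([1, 2, 3, 4], [2, 2]))  # [[1, 2], [3, 4]]
```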

@llvmbot (Member) commented Jun 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)


Full diff: https://github.com/llvm/llvm-project/pull/145539.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4-5)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+34-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..ddc80063fd340 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5881,14 +5881,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   }
 
   // shape_cast(constant) -> constant
-  if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+  if (auto denseAttr =
+          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+    return denseAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
     return ub::PoisonAttr::get(getContext());
-  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..a06a98ee1b93b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1219,11 +1219,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
 
 // -----
 
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
 //       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
   %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
   %cst_1 = arith.constant dense<1> : vector<12x2xi32>
   %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1233,6 +1233,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+//               CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+//               CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+  %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+  %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+  return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+//  CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+//  CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+//      CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
+  %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+  %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+  return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_poison
 //       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
 //       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
@@ -1549,7 +1579,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1636,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 //       CHECK:   vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32> 
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32

@newling (Contributor, Author) commented Jul 8, 2025

Ping

@banach-space (Contributor) left a comment

LGTM, thanks!

> Potential downside: It is possible with this folder to end up with 2 large constants instead of 1 large constant and 1 shape_cast:

You could remove this potentially negative impact (*) by checking the number of users and only folding when there is exactly one. I wouldn't, unless you have a case where this is harmful.

Please wait ~1 day before landing, just in case somebody else wants to chime in.

(*) Further discussion implies that the impact would be negligible 🤷🏻

@newling (Contributor, Author) commented Jul 10, 2025

> LGTM, thanks!
>
> You could remove this potentially negative impact (*) by checking the number of users and only folding when there is exactly one. I wouldn't, unless you have a case where this is harmful.

Thanks for reviewing!

My preference would be to keep it as it is, rather than conditioning on a single user.

Two reasons:

  1. Permuting 2 lines in a pass

 auto v0 = createOrFold<ShapeCast>(myConst);
 auto v1 = create<AnotherOp>(myConst);

might result in different IR, which seems unintuitive.

  2. It's better not to make shape_cast a special case. Consider
func.func @foo() -> (vector<4xi32>, vector<4xi32>) {
  %c0 = arith.constant dense<[1, 2, 3, -4]> : vector<4xi32>
  %foo = math.absi %c0 : vector<4xi32>
  return %c0, %foo : vector<4xi32>, vector<4xi32>
}

This is folded to

%cst = arith.constant dense<[1, 2, 3, -4]> : vector<4xi32>
%cst_0 = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
return %cst, %cst_0 : vector<4xi32>, vector<4xi32>

so maybe the same sort of thing should just happen with vector.shape_cast? There's maybe even an argument that it's more sensible for shape_cast, where the underlying storage is the same, which isn't true in the math.absi case above.
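The math.absi example above illustrates the same pattern: the elementwise folder materializes a second constant, and the original survives only because it still has another use. A minimal pure-Python sketch of that fold (hypothetical names, not the actual MLIR folder):

```python
def fold_absi(elements):
    # Elementwise constant fold, as the math.absi folder does
    # conceptually: build a new constant from the input constant.
    return [abs(x) for x in elements]

c0 = [1, 2, 3, -4]       # %c0    = arith.constant dense<[1, 2, 3, -4]>
folded = fold_absi(c0)   # %cst_0 = arith.constant dense<[1, 2, 3, 4]>
# %c0 is still returned alongside the folded value, so after folding
# both constants remain live, the same situation as with shape_cast.
print(c0, folded)  # [1, 2, 3, -4] [1, 2, 3, 4]
```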
