Skip to content

Revert [mlir][vector] Use vector.broadcast in place of vector.splat #148937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 16, 2025

Conversation

newling
Copy link
Contributor

@newling newling commented Jul 15, 2025

This reverts commit a599e60.

This should only be landed after #148027, at which point we don't need to assume that vector.broadcast has been lowered to another form.

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: James Newling (newling)

Changes

This reverts commit a599e60.

This should only be landed after #148027, at which point we don't need to assume that vector.broadcast has been lowered to another form.


Full diff: https://github.com/llvm/llvm-project/pull/148937.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1-1)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+1-2)
  • (modified) mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 22dd3bd0ea98f..62022bfb7df1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
       vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
+                    arith::ConstantOp, vector::SplatOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index bcd62acf6b9ce..d2c94b124cdfb 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
-  Value res = b.create<vector::BroadcastOp>(loc, vt, loads[0]);
+  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
   foreachIndividualVectorElement(
       res,
       /*applyFn=*/
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 39e484927052c..dc76441b27c02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -434,8 +434,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
           // Compute the offset
           Value inc = rewriter.create<arith::ConstantIndexOp>(
               loc, i * blockedChunkSize);
-          Value incVec =
-              rewriter.create<vector::BroadcastOp>(loc, indiceType, inc);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice =
               rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index bbe27fe1b99d9..07e03f3b8473d 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
 // CHECK:           %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
 // CHECK:           %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
 //
 // CHECK:           %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
 // CHECK:           %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
 //
 // CHECK:           %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
-// CHECK:           %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
+// CHECK:           %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
 // CHECK:           %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-nvgpu

Author: James Newling (newling)

Changes

This reverts commit a599e60.

This should only be landed after #148027, at which point we don't need to assume that vector.broadcast has been lowered to another form.


Full diff: https://github.com/llvm/llvm-project/pull/148937.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1-1)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+1-2)
  • (modified) mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 22dd3bd0ea98f..62022bfb7df1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
       vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
+                    arith::ConstantOp, vector::SplatOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index bcd62acf6b9ce..d2c94b124cdfb 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
-  Value res = b.create<vector::BroadcastOp>(loc, vt, loads[0]);
+  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
   foreachIndividualVectorElement(
       res,
       /*applyFn=*/
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 39e484927052c..dc76441b27c02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -434,8 +434,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
           // Compute the offset
           Value inc = rewriter.create<arith::ConstantIndexOp>(
               loc, i * blockedChunkSize);
-          Value incVec =
-              rewriter.create<vector::BroadcastOp>(loc, indiceType, inc);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice =
               rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index bbe27fe1b99d9..07e03f3b8473d 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
 // CHECK:           %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
 // CHECK:           %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
 //
 // CHECK:           %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
 // CHECK:           %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
 //
 // CHECK:           %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
-// CHECK:           %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
+// CHECK:           %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
 // CHECK:           %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>

@banach-space
Copy link
Contributor

Hm , where is this enigmatic commit?

$ git log -1
commit d67d91a9906366585162cebf292f923a3f28c8a6 (HEAD -> main, origin/main, origin/HEAD)
Author: Sudharsan Veeravalli <quic_svs@quicinc.com>
Date:   Wed Jul 16 00:31:33 2025 +0530

    [RISCV] Fix issues in ORI to QC.INSBI transformation (#148809)

    The transformation done in #147349 was incorrect since we were not
    passing the input node of the `OR` instruction to the `QC.INSBI`
    instruction leading to the generated instruction doing the wrong thing.
    In order to do this we first needed to add the output register to
    `QC.INSBI` as being both an input and output.

    The code produced after the above fix will need a copy (mv) to preserve
    the register input to the OR instruction if it has more than one use
    making the transformation net neutral ( `6-byte QC.E.ORI/ORAI` vs
    `2-byte C.MV + 4-byte QC.INSB`I). Avoid doing the transformation if
    there is more than one use of the input register to the OR instruction.
$ git show a599e60aff0c9981c4173db69e01eada2091cc23
fatal: bad object a599e60aff0c9981c4173db69e01eada2091cc23

@newling
Copy link
Contributor Author

newling commented Jul 15, 2025

@banach-space can you elaborate? Quite likely I did this revert incorrectly.

@newling
Copy link
Contributor Author

newling commented Jul 15, 2025

maybe I should have used the PR's squashed commit 9987573 instead of the commit a599e60 which is part of the original PR (the PR had only 1 commit).

@banach-space
Copy link
Contributor

@banach-space can you elaborate? Quite likely I did this revert incorrectly.

When I use this link --> a599e60 <--, GitHub doesn't show which PR that commit belongs to. And there is no such commit in-tree.

IIUC, this is reverting a part of #148028?

@newling
Copy link
Contributor Author

newling commented Jul 15, 2025

@banach-space can you elaborate? Quite likely I did this revert incorrectly.

When I use this link --> a599e60 <--, GitHub doesn't show which PR that commit belongs to. And there is no such commit in-tree.

IIUC, this is reverting a part of #148028?

It is reverting the whole of #148028. Which is the commit 9987573 on main, and was the commit a599e60 on the PR branch. So next time I won't mention the commit(s) on the PR branch. Thanks for pointing this mistake out to me!

@newling
Copy link
Contributor Author

newling commented Jul 15, 2025

Windows CI failure seems to be unrelated. I will try again tomorrow

@newling newling merged commit 228c45f into llvm:main Jul 16, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants