Skip to content

Commit 9987573

Browse files
authored
[mlir][vector] Use vector.broadcast in place of vector.splat (#148028)
Part of deprecation of vector.splat RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4 More complete deprecation: #147818
1 parent 027f5ba commit 9987573

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
123123
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
124124
[&](Operation *op) { return converter.isLegal(op); });
125125
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126-
arith::ConstantOp, vector::SplatOp>();
126+
arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
127127
}
128128

129129
void EmulateUnsupportedFloatsPass::runOnOperation() {

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
688688

689689
Type elementType = getElementTypeOrSelf(memref.getType());
690690
auto vt = VectorType::get(vectorShape, elementType);
691-
Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
691+
Value res = b.create<vector::BroadcastOp>(loc, vt, loads[0]);
692692
foreachIndividualVectorElement(
693693
res,
694694
/*applyFn=*/

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
438438
// Compute the offset
439439
Value inc = rewriter.create<arith::ConstantIndexOp>(
440440
loc, i * blockedChunkSize);
441-
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
441+
Value incVec =
442+
rewriter.create<vector::BroadcastOp>(loc, indiceType, inc);
442443
Value offsetIndice =
443444
rewriter.create<arith::AddIOp>(loc, indice, incVec);
444445

mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
2020
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
2121
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
2222
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
23-
// CHECK: %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
23+
// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
2424
// CHECK: %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
2525
// CHECK: %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
2626
//
2727
// CHECK: %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
2828
// CHECK: %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
2929
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
30-
// CHECK: %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
30+
// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
3131
// CHECK: %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
3232
//
3333
// CHECK: %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
4242
// CHECK: %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
4343
// CHECK: %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
4444
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
45-
// CHECK: %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
45+
// CHECK: %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
4646
// CHECK: %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
4747
// CHECK: %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
4848
// CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>

0 commit comments

Comments
 (0)