From a599e60aff0c9981c4173db69e01eada2091cc23 Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 10 Jul 2025 11:31:44 -0700 Subject: [PATCH] add miscellaneous changes --- .../Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp | 2 +- mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 2 +- mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 3 ++- mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir | 6 +++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 62022bfb7df1e..22dd3bd0ea98f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp(); + arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index d2c94b124cdfb..bcd62acf6b9ce 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); - Value res = b.create(loc, vt, loads[0]); + Value res = b.create(loc, vt, loads[0]); foreachIndividualVectorElement( res, /*applyFn=*/ diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 2c48a735bf956..9ebac21770136 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -438,7 +438,8 @@ struct UnrollCreateDescOp : public UnrollPattern { // Compute the offset Value inc = rewriter.create( loc, i * blockedChunkSize); - Value incVec = rewriter.create(loc, indiceType, inc); + Value incVec = + rewriter.create(loc, indiceType, inc); Value offsetIndice = rewriter.create(loc, indice, incVec); diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir index 07e03f3b8473d..bbe27fe1b99d9 100644 --- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir +++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir @@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global( // CHECK: %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]] // CHECK: %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]] // CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32> -// CHECK: %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32> +// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32> // CHECK: %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32> // CHECK: %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32> // // CHECK: %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]] // CHECK: %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]] // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32> -// CHECK: %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32> +// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32> // CHECK: %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32> // // CHECK: %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]] @@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global( // CHECK: %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]] // CHECK: %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]] // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32> -// CHECK: %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32> +// CHECK: %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32> // CHECK: %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32> // CHECK: %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32> // CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>