Skip to content

Commit 826ee25

Browse files
committed
fix arith+smt+create impmlicit
1 parent 5e24434 commit 826ee25

File tree

10 files changed

+269
-167
lines changed

10 files changed

+269
-167
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,21 @@ class AffineDmaStartOp
114114
AffineMap tagMap, ValueRange tagIndices, Value numElements,
115115
Value stride = nullptr, Value elementsPerStride = nullptr);
116116

117+
static AffineDmaStartOp
118+
create(OpBuilder &builder, Location location, Value srcMemRef,
119+
AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
120+
AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
121+
AffineMap tagMap, ValueRange tagIndices, Value numElements,
122+
Value stride = nullptr, Value elementsPerStride = nullptr);
123+
124+
static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef,
125+
AffineMap srcMap, ValueRange srcIndices,
126+
Value destMemRef, AffineMap dstMap,
127+
ValueRange destIndices, Value tagMemRef,
128+
AffineMap tagMap, ValueRange tagIndices,
129+
Value numElements, Value stride = nullptr,
130+
Value elementsPerStride = nullptr);
131+
117132
/// Returns the operand index of the source memref.
118133
unsigned getSrcMemRefOperandIndex() { return 0; }
119134

@@ -319,6 +334,12 @@ class AffineDmaWaitOp
319334

320335
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
321336
AffineMap tagMap, ValueRange tagIndices, Value numElements);
337+
static AffineDmaWaitOp create(OpBuilder &builder, Location location,
338+
Value tagMemRef, AffineMap tagMap,
339+
ValueRange tagIndices, Value numElements);
340+
static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef,
341+
AffineMap tagMap, ValueRange tagIndices,
342+
Value numElements);
322343

323344
static StringRef getOperationName() { return "affine.dma_wait"; }
324345

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,25 @@ class ConstantIntOp : public arith::ConstantOp {
6161
unsigned width);
6262
static ConstantIntOp create(OpBuilder &builder, Location location,
6363
int64_t value, unsigned width);
64+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value,
65+
unsigned width);
6466

6567
/// Build a constant int op that produces an integer of the specified type,
6668
/// which must be an integer type.
6769
static void build(OpBuilder &builder, OperationState &result, Type type,
6870
int64_t value);
6971
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
7072
int64_t value);
73+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
74+
int64_t value);
7175

7276
/// Build a constant int op that produces an integer from an APInt
7377
static void build(OpBuilder &builder, OperationState &result, Type type,
7478
const APInt &value);
7579
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
7680
const APInt &value);
81+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
82+
const APInt &value);
7783

7884
inline int64_t value() {
7985
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -93,6 +99,8 @@ class ConstantFloatOp : public arith::ConstantOp {
9399
const APFloat &value);
94100
static ConstantFloatOp create(OpBuilder &builder, Location location,
95101
FloatType type, const APFloat &value);
102+
static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type,
103+
const APFloat &value);
96104

97105
inline APFloat value() {
98106
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
@@ -110,6 +118,7 @@ class ConstantIndexOp : public arith::ConstantOp {
110118
static void build(OpBuilder &builder, OperationState &result, int64_t value);
111119
static ConstantIndexOp create(OpBuilder &builder, Location location,
112120
int64_t value);
121+
static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value);
113122

114123
inline int64_t value() {
115124
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2121
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
2222
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
23+
#include "mlir/IR/Builders.h"
2324
#include "mlir/IR/BuiltinTypes.h"
24-
#include "mlir/IR/ImplicitLocOpBuilder.h"
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/IR/TypeUtilities.h"
2727
#include "mlir/IR/Value.h"
@@ -114,7 +114,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
114114

115115
auto makeConst = [&](int32_t index) -> Value {
116116
return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
117-
rewriter.getI32IntegerAttr(index));
117+
rewriter.getI32IntegerAttr(index));
118118
};
119119

120120
if (arrayType) {
@@ -147,11 +147,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
147147
Value x1 =
148148
LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
149149
Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
150-
i * 2 + 1);
150+
i * 2 + 1);
151151
vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152-
x1, makeConst(0));
152+
x1, makeConst(0));
153153
vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
154-
x2, makeConst(1));
154+
x2, makeConst(1));
155155
elements.push_back(vec);
156156
}
157157
}
@@ -160,7 +160,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
160160
Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
161161
for (const auto &el : llvm::enumerate(elements)) {
162162
result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
163-
el.index());
163+
el.index());
164164
}
165165
return result;
166166
}
@@ -208,8 +208,8 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
208208
innerArrayTy.getElementType() == f32Ty)) {
209209
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210210
idx < innerSize; idx++) {
211-
result.push_back(LLVM::ExtractElementOp::create(b,
212-
toUse,
211+
result.push_back(LLVM::ExtractElementOp::create(
212+
b, toUse,
213213
LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
214214
}
215215
continue;
@@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
285285
Value srcPtr =
286286
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
287287
adaptor.getSrcMemref(), adaptor.getIndices());
288-
Value ldMatrixResult = NVVM::LdMatrixOp::create(b,
289-
ldMatrixResultType, srcPtr,
288+
Value ldMatrixResult = NVVM::LdMatrixOp::create(
289+
b, ldMatrixResultType, srcPtr,
290290
/*num=*/op.getNumTiles(),
291291
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292292
: NVVM::MMALayout::row);
@@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
375375
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376376
Type intrinsicResTy = inferIntrinsicResultType(
377377
typeConverter->convertType(op->getResultTypes()[0]));
378-
Value intrinsicResult = NVVM::MmaOp::create(b,
379-
intrinsicResTy, matA, matB, matC,
380-
/*shape=*/gemmShape,
381-
/*b1Op=*/std::nullopt,
382-
/*intOverflow=*/overflow,
383-
/*multiplicandPtxTypes=*/
384-
std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385-
/*multiplicandLayouts=*/
386-
std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387-
NVVM::MMALayout::col});
378+
Value intrinsicResult =
379+
NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380+
/*shape=*/gemmShape,
381+
/*b1Op=*/std::nullopt,
382+
/*intOverflow=*/overflow,
383+
/*multiplicandPtxTypes=*/
384+
std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385+
/*multiplicandLayouts=*/
386+
std::array<NVVM::MMALayout, 2>{
387+
NVVM::MMALayout::row, NVVM::MMALayout::col});
388388
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389389
desiredRetTy, intrinsicResult,
390390
rewriter));
@@ -566,14 +566,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
566566
asmVals.push_back(indexData);
567567

568568
return LLVM::InlineAsmOp::create(b,
569-
/*resultTypes=*/intrinsicResultType,
570-
/*operands=*/asmVals,
571-
/*asm_string=*/asmStr,
572-
/*constraints=*/constraintStr,
573-
/*has_side_effects=*/true,
574-
/*is_align_stack=*/false, LLVM::TailCallKind::None,
575-
/*asm_dialect=*/asmDialectAttr,
576-
/*operand_attrs=*/ArrayAttr());
569+
/*resultTypes=*/intrinsicResultType,
570+
/*operands=*/asmVals,
571+
/*asm_string=*/asmStr,
572+
/*constraints=*/constraintStr,
573+
/*has_side_effects=*/true,
574+
/*is_align_stack=*/false,
575+
LLVM::TailCallKind::None,
576+
/*asm_dialect=*/asmDialectAttr,
577+
/*operand_attrs=*/ArrayAttr());
577578
}
578579

579580
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -698,12 +699,12 @@ struct NVGPUAsyncCopyLowering
698699
// filled with zeros.
699700
Value c3I32 =
700701
LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
701-
Value bitwidth = LLVM::ConstantOp::create(b,
702-
b.getI32Type(),
702+
Value bitwidth = LLVM::ConstantOp::create(
703+
b, b.getI32Type(),
703704
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704705
Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
705-
srcBytes = LLVM::LShrOp::create(b,
706-
LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
706+
srcBytes = LLVM::LShrOp::create(
707+
b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
707708
}
708709
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709710
// 16 dst bytes.
@@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering
712713
? NVVM::LoadCacheModifierKind::CG
713714
: NVVM::LoadCacheModifierKind::CA;
714715

715-
NVVM::CpAsyncOp::create(b,
716-
dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
716+
NVVM::CpAsyncOp::create(
717+
b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717718
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718719
srcBytes);
719720

720721
// Drop the result token.
721-
Value zero = LLVM::ConstantOp::create(b,
722-
IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
722+
Value zero =
723+
LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
724+
rewriter.getI32IntegerAttr(0));
723725
rewriter.replaceOp(op, zero);
724726
return success();
725727
}
@@ -735,9 +737,9 @@ struct NVGPUAsyncCreateGroupLowering
735737
ConversionPatternRewriter &rewriter) const override {
736738
NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
737739
// Drop the result token.
738-
Value zero = LLVM::ConstantOp::create(rewriter,
739-
op->getLoc(), IntegerType::get(op.getContext(), 32),
740-
rewriter.getI32IntegerAttr(0));
740+
Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
741+
IntegerType::get(op.getContext(), 32),
742+
rewriter.getI32IntegerAttr(0));
741743
rewriter.replaceOp(op, zero);
742744
return success();
743745
}
@@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering
771773
SymbolTable symbolTable(moduleOp);
772774
OpBuilder::InsertionGuard guard(rewriter);
773775
rewriter.setInsertionPoint(&moduleOp.front());
774-
auto global = memref::GlobalOp::create(rewriter,
775-
funcOp->getLoc(), "__mbarrier",
776+
auto global = memref::GlobalOp::create(
777+
rewriter, funcOp->getLoc(), "__mbarrier",
776778
/*sym_visibility=*/rewriter.getStringAttr("private"),
777779
/*type=*/barrierType,
778780
/*initial_value=*/ElementsAttr(),
@@ -1119,7 +1121,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
11191121

11201122
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
11211123
return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1122-
b.getI32IntegerAttr(index));
1124+
b.getI32IntegerAttr(index));
11231125
}
11241126

11251127
/// Returns a Value that holds data type enum that is expected by CUDA driver.
@@ -1182,11 +1184,11 @@ struct NVGPUTmaCreateDescriptorOpLowering
11821184
auto promotedOperands = getTypeConverter()->promoteOperands(
11831185
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
11841186

1185-
Value boxArrayPtr = LLVM::AllocaOp::create(b, llvmPointerType, llvmInt64Type,
1186-
makeI64Const(b, 5));
1187+
Value boxArrayPtr = LLVM::AllocaOp::create(
1188+
b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
11871189
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
11881190
Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1189-
boxArrayPtr, makeI64Const(b, index));
1191+
boxArrayPtr, makeI64Const(b, index));
11901192
LLVM::StoreOp::create(b, value, gep);
11911193
}
11921194

@@ -1430,9 +1432,9 @@ struct NVGPUWarpgroupMmaOpLowering
14301432
auto overflow = NVVM::MMAIntOverflowAttr::get(
14311433
op->getContext(), NVVM::MMAIntOverflow::wrapped);
14321434

1433-
return NVVM::WgmmaMmaAsyncOp::create(b,
1434-
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1435-
itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1435+
return NVVM::WgmmaMmaAsyncOp::create(
1436+
b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1437+
itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
14361438
overflow);
14371439
}
14381440

@@ -1444,15 +1446,16 @@ struct NVGPUWarpgroupMmaOpLowering
14441446
// Perform GEMM
14451447
SmallVector<Value> wgmmaResults;
14461448
for (int i = 0; i < iterationM; ++i) {
1447-
Value matrixC = LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1449+
Value matrixC =
1450+
LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
14481451
for (int j = 0; j < iterationN; ++j)
14491452
for (int k = 0; k < iterationK; ++k)
14501453
matrixC = generateWgmma(i, j, k, matrixC);
14511454
wgmmaResults.push_back(matrixC);
14521455
}
14531456
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
14541457
wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1455-
wgmmaResult, matrix, idx);
1458+
wgmmaResult, matrix, idx);
14561459
}
14571460
return wgmmaResult;
14581461
}
@@ -1486,9 +1489,9 @@ struct NVGPUWarpgroupMmaOpLowering
14861489
/// (WgmmaGroupSyncAlignedOp) for group synchronization
14871490
/// (WgmmaWaitGroupSyncOp) after the instructions.
14881491
Value generateWarpgroupMma() {
1489-
NVVM::WgmmaFenceAlignedOp::create(b, );
1492+
NVVM::WgmmaFenceAlignedOp::create(b);
14901493
Value wgmmaResult = generateWgmmaGroup();
1491-
NVVM::WgmmaGroupSyncAlignedOp::create(b, );
1494+
NVVM::WgmmaGroupSyncAlignedOp::create(b);
14921495
NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
14931496
return wgmmaResult;
14941497
}
@@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
16261629
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
16271630
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
16281631
auto structType = cast<LLVM::LLVMStructType>(matrixD);
1629-
Value innerStructValue = LLVM::ExtractValueOp::create(b, matriDValue, idx);
1632+
Value innerStructValue =
1633+
LLVM::ExtractValueOp::create(b, matriDValue, idx);
16301634
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
16311635
offset += structType.getBody().size();
16321636
}
@@ -1656,15 +1660,15 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
16561660
auto structType = cast<LLVM::LLVMStructType>(s);
16571661
Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
16581662
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1659-
structValue = LLVM::InsertValueOp::create(b,
1660-
structType, structValue, zero, ArrayRef<int64_t>({i}));
1663+
structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1664+
zero, ArrayRef<int64_t>({i}));
16611665
}
16621666
innerStructs.push_back(structValue);
16631667
}
16641668
// Pack the inner structs into a single struct
16651669
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
16661670
packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1667-
packStruct, matrix, idx);
1671+
packStruct, matrix, idx);
16681672
}
16691673
rewriter.replaceOp(op, packStruct);
16701674
return success();

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ class ExecutionModePattern
709709
loc, llvmI32Type,
710710
rewriter.getI32IntegerAttr(
711711
static_cast<uint32_t>(executionModeAttr.getValue())));
712-
structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
712+
structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
713713
executionMode, 0);
714714

715715
// Insert extra operands if they exist into execution mode info struct.

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ static Value createFPReductionComparisonOpLowering(
670670

671671
if (accumulator) {
672672
result =
673-
typename VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(rewriter,
673+
rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
674674
loc, result, accumulator);
675675
}
676676

0 commit comments

Comments
 (0)