20
20
#include " mlir/Dialect/MemRef/IR/MemRef.h"
21
21
#include " mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22
22
#include " mlir/Dialect/SCF/Transforms/Patterns.h"
23
+ #include " mlir/IR/Builders.h"
23
24
#include " mlir/IR/BuiltinTypes.h"
24
- #include " mlir/IR/ImplicitLocOpBuilder.h"
25
25
#include " mlir/IR/PatternMatch.h"
26
26
#include " mlir/IR/TypeUtilities.h"
27
27
#include " mlir/IR/Value.h"
@@ -114,7 +114,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
114
114
115
115
auto makeConst = [&](int32_t index) -> Value {
116
116
return LLVM::ConstantOp::create (rewriter, loc, IntegerType::get (ctx, 32 ),
117
- rewriter.getI32IntegerAttr (index));
117
+ rewriter.getI32IntegerAttr (index));
118
118
};
119
119
120
120
if (arrayType) {
@@ -147,11 +147,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
147
147
Value x1 =
148
148
LLVM::ExtractValueOp::create (rewriter, loc, intrinsicResult, i * 2 );
149
149
Value x2 = LLVM::ExtractValueOp::create (rewriter, loc, intrinsicResult,
150
- i * 2 + 1 );
150
+ i * 2 + 1 );
151
151
vec = LLVM::InsertElementOp::create (rewriter, loc, vec.getType (), vec,
152
- x1, makeConst (0 ));
152
+ x1, makeConst (0 ));
153
153
vec = LLVM::InsertElementOp::create (rewriter, loc, vec.getType (), vec,
154
- x2, makeConst (1 ));
154
+ x2, makeConst (1 ));
155
155
elements.push_back (vec);
156
156
}
157
157
}
@@ -160,7 +160,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
160
160
Value result = LLVM::PoisonOp::create (rewriter, loc, arrayType);
161
161
for (const auto &el : llvm::enumerate (elements)) {
162
162
result = LLVM::InsertValueOp::create (rewriter, loc, result, el.value (),
163
- el.index ());
163
+ el.index ());
164
164
}
165
165
return result;
166
166
}
@@ -208,8 +208,8 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
208
208
innerArrayTy.getElementType () == f32Ty)) {
209
209
for (unsigned idx = 0 , innerSize = innerArrayTy.getNumElements ();
210
210
idx < innerSize; idx++) {
211
- result.push_back (LLVM::ExtractElementOp::create (b,
212
- toUse,
211
+ result.push_back (LLVM::ExtractElementOp::create (
212
+ b, toUse,
213
213
LLVM::ConstantOp::create (b, i64Ty, b.getI64IntegerAttr (idx))));
214
214
}
215
215
continue ;
@@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
285
285
Value srcPtr =
286
286
getStridedElementPtr (rewriter, b.getLoc (), srcMemrefType,
287
287
adaptor.getSrcMemref (), adaptor.getIndices ());
288
- Value ldMatrixResult = NVVM::LdMatrixOp::create (b,
289
- ldMatrixResultType, srcPtr,
288
+ Value ldMatrixResult = NVVM::LdMatrixOp::create (
289
+ b, ldMatrixResultType, srcPtr,
290
290
/* num=*/ op.getNumTiles (),
291
291
/* layout=*/ op.getTranspose () ? NVVM::MMALayout::col
292
292
: NVVM::MMALayout::row);
@@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
375
375
Type desiredRetTy = typeConverter->convertType (op->getResultTypes ()[0 ]);
376
376
Type intrinsicResTy = inferIntrinsicResultType (
377
377
typeConverter->convertType (op->getResultTypes ()[0 ]));
378
- Value intrinsicResult = NVVM::MmaOp::create (b,
379
- intrinsicResTy, matA, matB, matC,
380
- /* shape=*/ gemmShape,
381
- /* b1Op=*/ std::nullopt,
382
- /* intOverflow=*/ overflow,
383
- /* multiplicandPtxTypes=*/
384
- std::array<NVVM::MMATypes, 2 >{*ptxTypeA, *ptxTypeB},
385
- /* multiplicandLayouts=*/
386
- std::array<NVVM::MMALayout, 2 >{NVVM::MMALayout::row,
387
- NVVM::MMALayout::col});
378
+ Value intrinsicResult =
379
+ NVVM::MmaOp::create (b, intrinsicResTy, matA, matB, matC,
380
+ /* shape=*/ gemmShape,
381
+ /* b1Op=*/ std::nullopt,
382
+ /* intOverflow=*/ overflow,
383
+ /* multiplicandPtxTypes=*/
384
+ std::array<NVVM::MMATypes, 2 >{*ptxTypeA, *ptxTypeB},
385
+ /* multiplicandLayouts=*/
386
+ std::array<NVVM::MMALayout, 2 >{
387
+ NVVM::MMALayout::row, NVVM::MMALayout::col});
388
388
rewriter.replaceOp (op, convertIntrinsicResult (op.getLoc (), intrinsicResTy,
389
389
desiredRetTy, intrinsicResult,
390
390
rewriter));
@@ -566,14 +566,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
566
566
asmVals.push_back (indexData);
567
567
568
568
return LLVM::InlineAsmOp::create (b,
569
- /* resultTypes=*/ intrinsicResultType,
570
- /* operands=*/ asmVals,
571
- /* asm_string=*/ asmStr,
572
- /* constraints=*/ constraintStr,
573
- /* has_side_effects=*/ true ,
574
- /* is_align_stack=*/ false , LLVM::TailCallKind::None,
575
- /* asm_dialect=*/ asmDialectAttr,
576
- /* operand_attrs=*/ ArrayAttr ());
569
+ /* resultTypes=*/ intrinsicResultType,
570
+ /* operands=*/ asmVals,
571
+ /* asm_string=*/ asmStr,
572
+ /* constraints=*/ constraintStr,
573
+ /* has_side_effects=*/ true ,
574
+ /* is_align_stack=*/ false ,
575
+ LLVM::TailCallKind::None,
576
+ /* asm_dialect=*/ asmDialectAttr,
577
+ /* operand_attrs=*/ ArrayAttr ());
577
578
}
578
579
579
580
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -698,12 +699,12 @@ struct NVGPUAsyncCopyLowering
698
699
// filled with zeros.
699
700
Value c3I32 =
700
701
LLVM::ConstantOp::create (b, b.getI32Type (), b.getI32IntegerAttr (3 ));
701
- Value bitwidth = LLVM::ConstantOp::create (b,
702
- b.getI32Type (),
702
+ Value bitwidth = LLVM::ConstantOp::create (
703
+ b, b .getI32Type (),
703
704
b.getI32IntegerAttr (srcMemrefType.getElementTypeBitWidth ()));
704
705
Value srcElementsI32 = LLVM::TruncOp::create (b, b.getI32Type (), srcBytes);
705
- srcBytes = LLVM::LShrOp::create (b,
706
- LLVM::MulOp::create (b, bitwidth, srcElementsI32), c3I32);
706
+ srcBytes = LLVM::LShrOp::create (
707
+ b, LLVM::MulOp::create (b, bitwidth, srcElementsI32), c3I32);
707
708
}
708
709
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709
710
// 16 dst bytes.
@@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering
712
713
? NVVM::LoadCacheModifierKind::CG
713
714
: NVVM::LoadCacheModifierKind::CA;
714
715
715
- NVVM::CpAsyncOp::create (b,
716
- dstPtr, scrPtr, rewriter.getI32IntegerAttr (sizeInBytes),
716
+ NVVM::CpAsyncOp::create (
717
+ b, dstPtr, scrPtr, rewriter.getI32IntegerAttr (sizeInBytes),
717
718
NVVM::LoadCacheModifierKindAttr::get (op->getContext (), cacheModifier),
718
719
srcBytes);
719
720
720
721
// Drop the result token.
721
- Value zero = LLVM::ConstantOp::create (b,
722
- IntegerType::get (op.getContext (), 32 ), rewriter.getI32IntegerAttr (0 ));
722
+ Value zero =
723
+ LLVM::ConstantOp::create (b, IntegerType::get (op.getContext (), 32 ),
724
+ rewriter.getI32IntegerAttr (0 ));
723
725
rewriter.replaceOp (op, zero);
724
726
return success ();
725
727
}
@@ -735,9 +737,9 @@ struct NVGPUAsyncCreateGroupLowering
735
737
ConversionPatternRewriter &rewriter) const override {
736
738
NVVM::CpAsyncCommitGroupOp::create (rewriter, op.getLoc ());
737
739
// Drop the result token.
738
- Value zero = LLVM::ConstantOp::create (rewriter,
739
- op-> getLoc (), IntegerType::get (op.getContext (), 32 ),
740
- rewriter.getI32IntegerAttr (0 ));
740
+ Value zero = LLVM::ConstantOp::create (rewriter, op-> getLoc (),
741
+ IntegerType::get (op.getContext (), 32 ),
742
+ rewriter.getI32IntegerAttr (0 ));
741
743
rewriter.replaceOp (op, zero);
742
744
return success ();
743
745
}
@@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering
771
773
SymbolTable symbolTable (moduleOp);
772
774
OpBuilder::InsertionGuard guard (rewriter);
773
775
rewriter.setInsertionPoint (&moduleOp.front ());
774
- auto global = memref::GlobalOp::create (rewriter,
775
- funcOp->getLoc (), " __mbarrier" ,
776
+ auto global = memref::GlobalOp::create (
777
+ rewriter, funcOp->getLoc (), " __mbarrier" ,
776
778
/* sym_visibility=*/ rewriter.getStringAttr (" private" ),
777
779
/* type=*/ barrierType,
778
780
/* initial_value=*/ ElementsAttr (),
@@ -1119,7 +1121,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
1119
1121
1120
1122
// Creates an LLVM::ConstantOp through `b` whose result type is
// `b.getIntegerType(64)` and whose value attribute is
// `b.getI32IntegerAttr(index)`, and returns it as a Value.
// NOTE(review): the attribute is built with getI32IntegerAttr while the
// requested result type is 64 bits wide -- confirm this is intended.
// In this diff view the removed (`-`) and added (`+`) lines below carry the
// same tokens.
static Value makeI64Const (ImplicitLocOpBuilder &b, int32_t index) {
1121
1123
return LLVM::ConstantOp::create (b, b.getIntegerType (64 ),
1122
- b.getI32IntegerAttr (index));
1124
+ b.getI32IntegerAttr (index));
1123
1125
}
1124
1126
1125
1127
/// Returns a Value that holds the data type enum expected by the CUDA driver.
@@ -1182,11 +1184,11 @@ struct NVGPUTmaCreateDescriptorOpLowering
1182
1184
auto promotedOperands = getTypeConverter ()->promoteOperands (
1183
1185
b.getLoc (), op->getOperands (), adaptor.getOperands (), b);
1184
1186
1185
- Value boxArrayPtr = LLVM::AllocaOp::create (b, llvmPointerType, llvmInt64Type,
1186
- makeI64Const (b, 5 ));
1187
+ Value boxArrayPtr = LLVM::AllocaOp::create (
1188
+ b, llvmPointerType, llvmInt64Type, makeI64Const (b, 5 ));
1187
1189
for (auto [index, value] : llvm::enumerate (adaptor.getBoxDimensions ())) {
1188
1190
Value gep = LLVM::GEPOp::create (b, llvmPointerType, llvmPointerType,
1189
- boxArrayPtr, makeI64Const (b, index));
1191
+ boxArrayPtr, makeI64Const (b, index));
1190
1192
LLVM::StoreOp::create (b, value, gep);
1191
1193
}
1192
1194
@@ -1430,9 +1432,9 @@ struct NVGPUWarpgroupMmaOpLowering
1430
1432
auto overflow = NVVM::MMAIntOverflowAttr::get (
1431
1433
op->getContext (), NVVM::MMAIntOverflow::wrapped);
1432
1434
1433
- return NVVM::WgmmaMmaAsyncOp::create (b,
1434
- matrixC.getType (), matrixC, descriptorA, descriptorB, shape, itypeA ,
1435
- itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1435
+ return NVVM::WgmmaMmaAsyncOp::create (
1436
+ b, matrixC.getType (), matrixC, descriptorA, descriptorB, shape,
1437
+ itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1436
1438
overflow);
1437
1439
}
1438
1440
@@ -1444,15 +1446,16 @@ struct NVGPUWarpgroupMmaOpLowering
1444
1446
// Perform GEMM
1445
1447
SmallVector<Value> wgmmaResults;
1446
1448
for (int i = 0 ; i < iterationM; ++i) {
1447
- Value matrixC = LLVM::ExtractValueOp::create (b, adaptor.getMatrixC (), i);
1449
+ Value matrixC =
1450
+ LLVM::ExtractValueOp::create (b, adaptor.getMatrixC (), i);
1448
1451
for (int j = 0 ; j < iterationN; ++j)
1449
1452
for (int k = 0 ; k < iterationK; ++k)
1450
1453
matrixC = generateWgmma (i, j, k, matrixC);
1451
1454
wgmmaResults.push_back (matrixC);
1452
1455
}
1453
1456
for (auto [idx, matrix] : llvm::enumerate (wgmmaResults)) {
1454
1457
wgmmaResult = LLVM::InsertValueOp::create (b, wgmmaResult.getType (),
1455
- wgmmaResult, matrix, idx);
1458
+ wgmmaResult, matrix, idx);
1456
1459
}
1457
1460
return wgmmaResult;
1458
1461
}
@@ -1486,9 +1489,9 @@ struct NVGPUWarpgroupMmaOpLowering
1486
1489
// / (WgmmaGroupSyncAlignedOp) for group synchronization
1487
1490
// / (WgmmaWaitGroupSyncOp) after the instructions.
1488
1491
// Creates, in order: a WgmmaFenceAlignedOp, the value returned by
// generateWgmmaGroup(), a WgmmaGroupSyncAlignedOp, and a WgmmaWaitGroupSyncOp
// built from op.getWaitGroup(). Returns the value from generateWgmmaGroup().
Value generateWarpgroupMma () {
1489
1492
// Removed line: the call had a dangling comma after `b` (empty argument).
- NVVM::WgmmaFenceAlignedOp::create (b, );
1490
1493
// Added line: same op, created with `b` as the only argument.
+ NVVM::WgmmaFenceAlignedOp::create (b);
1491
1494
Value wgmmaResult = generateWgmmaGroup ();
1492
1495
// Removed line: same dangling-comma form as the fence op above.
- NVVM::WgmmaGroupSyncAlignedOp::create (b, );
1493
1496
// Added line: same op, created with `b` as the only argument.
+ NVVM::WgmmaGroupSyncAlignedOp::create (b);
1494
1497
// The wait op takes the wait-group value read from `op`.
NVVM::WgmmaWaitGroupSyncOp::create (b, op.getWaitGroup ());
1493
1496
return wgmmaResult;
1494
1497
}
@@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
1626
1629
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType ());
1627
1630
for (auto [idx, matrixD] : llvm::enumerate (stype.getBody ())) {
1628
1631
auto structType = cast<LLVM::LLVMStructType>(matrixD);
1629
- Value innerStructValue = LLVM::ExtractValueOp::create (b, matriDValue, idx);
1632
+ Value innerStructValue =
1633
+ LLVM::ExtractValueOp::create (b, matriDValue, idx);
1630
1634
storeFragmentedMatrix (b, innerStructValue, op.getDstMemref (), offset);
1631
1635
offset += structType.getBody ().size ();
1632
1636
}
@@ -1656,15 +1660,15 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1656
1660
auto structType = cast<LLVM::LLVMStructType>(s);
1657
1661
Value structValue = LLVM::ExtractValueOp::create (b, packStruct, idx);
1658
1662
for (unsigned i = 0 ; i < structType.getBody ().size (); ++i) {
1659
- structValue = LLVM::InsertValueOp::create (b,
1660
- structType, structValue, zero, ArrayRef<int64_t >({i}));
1663
+ structValue = LLVM::InsertValueOp::create (b, structType, structValue,
1664
+ zero, ArrayRef<int64_t >({i}));
1661
1665
}
1662
1666
innerStructs.push_back (structValue);
1663
1667
}
1664
1668
// Pack the inner structs into a single struct
1665
1669
for (auto [idx, matrix] : llvm::enumerate (innerStructs)) {
1666
1670
packStruct = LLVM::InsertValueOp::create (b, packStruct.getType (),
1667
- packStruct, matrix, idx);
1671
+ packStruct, matrix, idx);
1668
1672
}
1669
1673
rewriter.replaceOp (op, packStruct);
1670
1674
return success ();
0 commit comments