Commit 09fe452

Make TmaCreateDescriptorOp accept static box dimensions, add a folder for it, and add tests.
1 parent aeb06c6 commit 09fe452

File tree: 7 files changed, +106 -7 lines


mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 4 additions & 2 deletions
@@ -546,12 +546,14 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   }];
 
   let arguments = (ins AnyUnrankedMemRef:$tensor,
-                       Variadic<Index>:$boxDimensions);
+                       Variadic<Index>:$boxDimensions,
+                       DenseI64ArrayAttr:$static_boxDimensions);
   let results = (outs NVGPU_TensorMapDescriptor:$tensorMap);
   let assemblyFormat = [{
-    $tensor `box` `[` $boxDimensions `]` attr-dict `:` type($tensor) `->` type($tensorMap)
+    $tensor `box` custom<DynamicIndexList>($boxDimensions, $static_boxDimensions) attr-dict `:` type($tensor) `->` type($tensorMap)
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def NVGPU_WarpgroupGenerateDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
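
With `custom<DynamicIndexList>` in the assembly format, box dimensions now print and parse like the mixed static/dynamic index lists used by ops such as `memref.subview`: static sizes appear inline as integer literals, dynamic sizes as SSA operands, and `static_boxDimensions` holds a `ShapedType::kDynamic` sentinel for each dynamic slot. A minimal sketch of the new syntax, mixing one dynamic and one static dimension (the alias `!tensorMap` and all value names are illustrative):

    !tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>

    // %dim0 stays a dynamic operand; 128 is recorded in static_boxDimensions.
    %desc = nvgpu.tma.create.descriptor %unranked box [%dim0, 128] : memref<*xf32> -> !tensorMap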

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 9 additions & 1 deletion
@@ -1183,9 +1183,17 @@ struct NVGPUTmaCreateDescriptorOpLowering
 
     Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                  makeI64Const(b, 5));
-    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
+    unsigned idx = 0;
+    ValueRange dynamicDim = adaptor.getBoxDimensions();
+    for (auto [index, shape] :
+         llvm::enumerate(adaptor.getStaticBoxDimensions())) {
       Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                         boxArrayPtr, makeI64Const(b, index));
+      Value value;
+      if (ShapedType::isDynamic(shape))
+        value = dynamicDim[idx++];
+      else
+        value = makeI64Const(b, shape);
       b.create<LLVM::StoreOp>(value, gep);
     }
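
The lowering now iterates over `static_boxDimensions` rather than the operand list: static entries are materialized on the spot with `makeI64Const`, while dynamic entries are consumed in order from `adaptor.getBoxDimensions()` through `idx`, so constant box sizes reach the five-slot i64 array without any intermediate index arithmetic. A hedged sketch of what `box [%d, 128]` produces (assuming `%d` is the already-converted i64 operand; value names are illustrative):

    %c5   = llvm.mlir.constant(5 : i32) : i64
    %box  = llvm.alloca %c5 x i64 : (i64) -> !llvm.ptr
    %c0   = llvm.mlir.constant(0 : i32) : i64
    %gep0 = llvm.getelementptr %box[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
    // Dynamic slot: value taken from the operand list.
    llvm.store %d, %gep0 : i64, !llvm.ptr
    %c1   = llvm.mlir.constant(1 : i32) : i64
    %gep1 = llvm.getelementptr %box[%c1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
    // Static slot: materialized directly as a constant.
    %c128 = llvm.mlir.constant(128 : i32) : i64
    llvm.store %c128, %gep1 : i64, !llvm.ptr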

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 47 additions & 0 deletions
@@ -23,6 +23,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -458,6 +459,10 @@ LogicalResult TmaAsyncStoreOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaCreateDescriptorOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TmaCreateDescriptorOp::verify() {
   if (getBoxDimensions().size() > kMaxTMATensorDimension) {
     return emitError() << "Maximum " << kMaxTMATensorDimension
@@ -472,6 +477,48 @@ LogicalResult TmaCreateDescriptorOp::verify() {
   return success();
 }
 
+static Value
+TmaCreateDescriptorFoldBoxConstant(TmaCreateDescriptorOp op,
+                                   TmaCreateDescriptorOp::FoldAdaptor adaptor) {
+  std::vector<int64_t> staticBoxDimensions = op.getStaticBoxDimensions().vec();
+  OperandRange dynamicBoxDimensions = op.getBoxDimensions();
+  SmallVector<Value> operands = {op.getTensor()};
+  ArrayRef<Attribute> dynamicBoxDimensionAttrs = adaptor.getBoxDimensions();
+  if (staticBoxDimensions.empty())
+    return {};
+
+  // If `opChange` is true, `op` is updated in place with the folded dimensions.
+  bool opChange = false;
+  unsigned idx = 0;
+
+  for (unsigned i = 0, e = staticBoxDimensions.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticBoxDimensions[i]))
+      continue;
+    Attribute dynamicBoxDimensionAttr = dynamicBoxDimensionAttrs[idx];
+    Value dynamicDimension = dynamicBoxDimensions[idx++];
+    if (auto attr =
+            mlir::dyn_cast_if_present<IntegerAttr>(dynamicBoxDimensionAttr)) {
+      staticBoxDimensions[i] = attr.getInt();
+      opChange = true;
+      continue;
+    }
+    operands.push_back(dynamicDimension);
+  }
+
+  if (opChange) {
+    op.setStaticBoxDimensions(staticBoxDimensions);
+    op.getOperation()->setOperands(operands);
+    return op.getResult();
+  }
+  return {};
+}
+
+OpFoldResult TmaCreateDescriptorOp::fold(FoldAdaptor adaptor) {
+  if (auto val = TmaCreateDescriptorFoldBoxConstant(*this, adaptor))
+    return val;
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU_WarpgroupGenerateDescriptorOp
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -962,6 +962,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
   SmallVector<Value> sizes =
       getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
 
+  SmallVector<int64_t> static_dims(sizes.size(), ShapedType::kDynamic);
   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
   Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
       loc,
@@ -972,7 +973,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
       TensorMapSwizzleKind::SWIZZLE_NONE,
       TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
       TensorMapInterleaveKind::INTERLEAVE_NONE),
-      unrankedMemRef, sizes);
+      unrankedMemRef, sizes, static_dims);
   return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
 }
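
Because `buildGlobalMemRefDescriptor` materializes every box size as an SSA `index` value, it passes a `static_dims` vector filled with `ShapedType::kDynamic`: all dimensions remain dynamic at creation time, and constant sizes are left for the new folder to promote during canonicalization. This is also why the tmaload-transform.mlir checks below change only in the printed `box [` spacing.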

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 26 additions & 0 deletions
@@ -813,6 +813,32 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
   func.return
 }
 
+func.func @create_tensor_map_constant_box_dim(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
+  %devicePtr2d_unranked = memref.cast %devicePtr2d : memref<64x128xf32> to memref<*xf32>
+  // CHECK: %[[C5_0:.*]] = llvm.mlir.constant(5 : i32) : i64
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[C5_0]] x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i64
+  // CHECK: %[[GEP_0:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C0_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : i32) : i64
+  // CHECK: llvm.store %[[C64]], %[[GEP_0]] : i64, !llvm.ptr
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i64
+  // CHECK: %[[GEP_1:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C128_0:.*]] = llvm.mlir.constant(128 : i32) : i64
+  // CHECK: llvm.store %[[C128_0]], %[[GEP_1]] : i64, !llvm.ptr
+  // CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA]])
+  %tensorMap2d = nvgpu.tma.create.descriptor %devicePtr2d_unranked box[64, 128] : memref<*xf32> -> !tensorMap2d
+  %devicePtr1d_unranked = memref.cast %devicePtr1d : memref<128xf32> to memref<*xf32>
+  // CHECK: %[[C5_1:.*]] = llvm.mlir.constant(5 : i32) : i64
+  // CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %[[C5_1]] x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i64
+  // CHECK: %[[GEP_2:.*]] = llvm.getelementptr %[[ALLOCA_1]]{{\[}}%[[C0_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C128_1:.*]] = llvm.mlir.constant(128 : i32) : i64
+  // CHECK: llvm.store %[[C128_1]], %[[GEP_2]] : i64, !llvm.ptr
+  // CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA_1]])
+  %tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[128] : memref<*xf32> -> !tensorMap1d
+  func.return
+}
+
 // CHECK-LABEL: @tma_prefetch(
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
 func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {

mlir/test/Dialect/NVGPU/canonicalization.mlir

Lines changed: 16 additions & 1 deletion
@@ -27,4 +27,19 @@ gpu.module @main_kernel {
     nvvm.cp.async.bulk.wait_group 0
     gpu.return
   }
-}
+}
+
+// -----
+
+!descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>, swizzle = none, l2promo=none, oob=zero, interleave=none>
+
+func.func @main() {
+  %a_host = memref.alloc() : memref<64x16xf16>
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+  %a_device = gpu.alloc() : memref<64x16xf16>
+  %a_device_unranked = memref.cast %a_device : memref<64x16xf16> to memref<*xf16>
+  // CHECK: nvgpu.tma.create.descriptor %{{.*}} box [64, 16]
+  %a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c64, %c16] : memref<*xf16> -> !descriptor
+  return
+}
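
Running this function through the canonicalizer exercises the new folder: `%c64` and `%c16` are absorbed into `static_boxDimensions`, so the op prints with the literal `box [64, 16]` that the CHECK line expects.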

mlir/test/Dialect/NVGPU/tmaload-transform.mlir

Lines changed: 2 additions & 2 deletions
@@ -18,12 +18,12 @@ func.func @main() {
   // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   // CHECK: %[[c64:.*]] = arith.constant 64 : index
   // CHECK: %[[c32:.*]] = arith.constant 32 : index
-  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
+  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box [%[[c64]], %[[c32]]]
   // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   // CHECK: %[[c8_2:.*]] = arith.constant 8 : index
   // CHECK: %[[c32_2:.*]] = arith.constant 32 : index
-  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
+  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box [%[[c8_2]], %[[c32_2]]]
   // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
