Skip to content

Commit c30b5b1

Browse files
[mlir][GPU][transform] Add gpu_to_rocdl conversion pattern (#146962)
Co-authored-by: Son Tuan Vu <vuson@google.com>
1 parent 19860ce commit c30b5b1

File tree

4 files changed

+55
-0
lines changed

4 files changed

+55
-0
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def ApplyGPUSubgroupReduceToNVVMConversionPatternsOp : Op<Transform_Dialect,
5454
let assemblyFormat = "attr-dict";
5555
}
5656

57+
def ApplyGPUToROCDLConversionPatternsOp : Op<Transform_Dialect,
58+
"apply_conversion_patterns.gpu.gpu_to_rocdl",
59+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
60+
["verifyTypeConverter"]>]> {
61+
let description = [{
62+
Collects patterns that convert GPU dialect ops to ROCDL dialect ops. These
63+
patterns require an "LLVMTypeConverter".
64+
}];
65+
let arguments = (ins StrAttr:$chipset);
66+
let assemblyFormat = [{
67+
`chipset` `=` $chipset attr-dict
68+
}];
69+
}
70+
5771
//===----------------------------------------------------------------------===//
5872
// Apply...PatternsOp
5973
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ add_mlir_dialect_library(MLIRGPUTransformOps
2424
# ConversionPatterns
2525
MLIRNVGPUToNVVM
2626
MLIRGPUToNVVMTransforms
27+
MLIRGPUToROCDLTransforms
2728
)

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
13+
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
1314
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1415
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1718
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
1819
#include "mlir/Dialect/GPU/Transforms/Passes.h"
1920
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
21+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
2022
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2123
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
2224
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -40,6 +42,7 @@
4042
#include "llvm/Support/Debug.h"
4143
#include "llvm/Support/ErrorHandling.h"
4244
#include "llvm/Support/InterleavedRange.h"
45+
#include "llvm/Support/LogicalResult.h"
4346
#include <type_traits>
4447

4548
using namespace mlir;
@@ -127,6 +130,41 @@ LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
127130
return success();
128131
}
129132

133+
void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
134+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
135+
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
136+
populateGpuMemorySpaceAttributeConversions(
137+
llvmTypeConverter, [](AddressSpace space) {
138+
switch (space) {
139+
case AddressSpace::Global:
140+
return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
141+
case AddressSpace::Workgroup:
142+
return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
143+
case AddressSpace::Private:
144+
return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
145+
}
146+
llvm_unreachable("unknown address space enum value");
147+
});
148+
FailureOr<amdgpu::Chipset> maybeChipset =
149+
amdgpu::Chipset::parse(getChipset());
150+
assert(llvm::succeeded(maybeChipset) && "expected valid chipset");
151+
populateGpuToROCDLConversionPatterns(
152+
llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset);
153+
}
154+
155+
LogicalResult
156+
transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
157+
transform::TypeConverterBuilderOpInterface builder) {
158+
FailureOr<amdgpu::Chipset> maybeChipset =
159+
amdgpu::Chipset::parse(getChipset());
160+
if (failed(maybeChipset)) {
161+
return emitOpError("Invalid chipset name: " + getChipset());
162+
}
163+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
164+
return emitOpError("expected LLVMTypeConverter");
165+
return success();
166+
}
167+
130168
//===----------------------------------------------------------------------===//
131169
// Apply...PatternsOp
132170
//===----------------------------------------------------------------------===//s

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5505,13 +5505,15 @@ cc_library(
55055505
":GPUDialect",
55065506
":GPUToGPURuntimeTransforms",
55075507
":GPUToNVVMTransforms",
5508+
":GPUToROCDLTransforms",
55085509
":GPUTransformOpsIncGen",
55095510
":GPUTransforms",
55105511
":IR",
55115512
":LLVMCommonConversion",
55125513
":MemRefDialect",
55135514
":NVGPUDialect",
55145515
":NVVMDialect",
5516+
":ROCDLDialect",
55155517
":SCFDialect",
55165518
":Support",
55175519
":TransformDialect",

0 commit comments

Comments
 (0)