|
10 | 10 |
|
11 | 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
12 | 12 | #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
| 13 | +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" |
13 | 14 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
14 | 15 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
15 | 16 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
16 | 17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
17 | 18 | #include "mlir/Dialect/GPU/TransformOps/Utils.h"
|
18 | 19 | #include "mlir/Dialect/GPU/Transforms/Passes.h"
|
19 | 20 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
| 21 | +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
20 | 22 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
21 | 23 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
|
22 | 24 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
|
40 | 42 | #include "llvm/Support/Debug.h"
|
41 | 43 | #include "llvm/Support/ErrorHandling.h"
|
42 | 44 | #include "llvm/Support/InterleavedRange.h"
|
| 45 | +#include "llvm/Support/LogicalResult.h" |
43 | 46 | #include <type_traits>
|
44 | 47 |
|
45 | 48 | using namespace mlir;
|
@@ -127,6 +130,41 @@ LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
|
127 | 130 | return success();
|
128 | 131 | }
|
129 | 132 |
|
| 133 | +void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns( |
| 134 | + TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 135 | + auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| 136 | + populateGpuMemorySpaceAttributeConversions( |
| 137 | + llvmTypeConverter, [](AddressSpace space) { |
| 138 | + switch (space) { |
| 139 | + case AddressSpace::Global: |
| 140 | + return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace; |
| 141 | + case AddressSpace::Workgroup: |
| 142 | + return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace; |
| 143 | + case AddressSpace::Private: |
| 144 | + return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace; |
| 145 | + } |
| 146 | + llvm_unreachable("unknown address space enum value"); |
| 147 | + }); |
| 148 | + FailureOr<amdgpu::Chipset> maybeChipset = |
| 149 | + amdgpu::Chipset::parse(getChipset()); |
| 150 | + assert(llvm::succeeded(maybeChipset) && "expected valid chipset"); |
| 151 | + populateGpuToROCDLConversionPatterns( |
| 152 | + llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset); |
| 153 | +} |
| 154 | + |
| 155 | +LogicalResult |
| 156 | +transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter( |
| 157 | + transform::TypeConverterBuilderOpInterface builder) { |
| 158 | + FailureOr<amdgpu::Chipset> maybeChipset = |
| 159 | + amdgpu::Chipset::parse(getChipset()); |
| 160 | + if (failed(maybeChipset)) { |
| 161 | + return emitOpError("Invalid chipset name: " + getChipset()); |
| 162 | + } |
| 163 | + if (builder.getTypeConverterType() != "LLVMTypeConverter") |
| 164 | + return emitOpError("expected LLVMTypeConverter"); |
| 165 | + return success(); |
| 166 | +} |
| 167 | + |
130 | 168 | //===----------------------------------------------------------------------===//
|
131 | 169 | // Apply...PatternsOp
|
132 | 170 | //===----------------------------------------------------------------------===//
|
|
0 commit comments