Skip to content

Commit 357377b

Browse files
committed
Address reviewer comments.
1 parent d9ff9e3 commit 357377b

File tree

4 files changed

+106
-109
lines changed

4 files changed

+106
-109
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1501,7 +1501,7 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
15011501

15021502
def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
15031503
let summary = "Convert XeVM to LLVM dialect";
1504-
let dependentDialects = ["xevm::XeVMDialect", ];
1504+
let dependentDialects = ["LLVM::LLVMDialect", ];
15051505
}
15061506

15071507
#endif // MLIR_CONVERSION_PASSES

mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ class DialectRegistry;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
1717
class Pass;
18-
} // namespace mlir
1918

20-
namespace mlir {
2119
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
2220
#include "mlir/Conversion/Passes.h.inc"
2321

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 37 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424
#include "llvm/ADT/TypeSwitch.h"
2525

26-
#define DEBUG_TYPE "xevm-to-llvm"
27-
2826
namespace mlir {
2927
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
3028
#include "mlir/Conversion/Passes.h.inc"
@@ -70,6 +68,9 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
7068
default:
7169
llvm_unreachable("unhandled integer type");
7270
}
71+
})
72+
.Default([](Type) -> std::string {
73+
llvm_unreachable("unhandled type for mangling");
7374
});
7475
}
7576

@@ -165,38 +166,18 @@ int32_t getL3CacheControl(OpType op) {
165166
if constexpr (isLoad) {
166167
switch (*op.getCacheControl()) {
167168
case LoadCacheControl::L1UC_L2UC_L3UC:
168-
control = 1;
169-
break;
170-
case LoadCacheControl::L1UC_L2UC_L3C:
171-
control = 2;
172-
break;
173169
case LoadCacheControl::L1UC_L2C_L3UC:
174-
control = 1;
175-
break;
176-
case LoadCacheControl::L1UC_L2C_L3C:
177-
control = 2;
178-
break;
179170
case LoadCacheControl::L1C_L2UC_L3UC:
180-
control = 1;
181-
break;
182-
case LoadCacheControl::L1C_L2UC_L3C:
183-
control = 2;
184-
break;
185171
case LoadCacheControl::L1C_L2C_L3UC:
186-
control = 1;
187-
break;
188-
case LoadCacheControl::L1C_L2C_L3C:
189-
control = 2;
190-
break;
191172
case LoadCacheControl::L1S_L2UC_L3UC:
192-
control = 1;
193-
break;
194-
case LoadCacheControl::L1S_L2UC_L3C:
195-
control = 2;
196-
break;
197173
case LoadCacheControl::L1S_L2C_L3UC:
198174
control = 1;
199175
break;
176+
case LoadCacheControl::L1UC_L2UC_L3C:
177+
case LoadCacheControl::L1UC_L2C_L3C:
178+
case LoadCacheControl::L1C_L2UC_L3C:
179+
case LoadCacheControl::L1C_L2C_L3C:
180+
case LoadCacheControl::L1S_L2UC_L3C:
200181
case LoadCacheControl::L1S_L2C_L3C:
201182
control = 2;
202183
break;
@@ -209,47 +190,21 @@ int32_t getL3CacheControl(OpType op) {
209190
} else {
210191
switch (*op.getCacheControl()) {
211192
case StoreCacheControl::L1UC_L2UC_L3UC:
212-
control = 1;
213-
break;
214-
case StoreCacheControl::L1UC_L2UC_L3WB:
215-
control = 2;
216-
break;
217193
case StoreCacheControl::L1UC_L2WB_L3UC:
218-
control = 1;
219-
break;
220-
case StoreCacheControl::L1UC_L2WB_L3WB:
221-
control = 2;
222-
break;
223194
case StoreCacheControl::L1WT_L2UC_L3UC:
224-
control = 1;
225-
break;
226-
case StoreCacheControl::L1WT_L2UC_L3WB:
227-
control = 2;
228-
break;
229195
case StoreCacheControl::L1WT_L2WB_L3UC:
230-
control = 1;
231-
break;
232-
case StoreCacheControl::L1WT_L2WB_L3WB:
233-
control = 2;
234-
break;
235196
case StoreCacheControl::L1S_L2UC_L3UC:
236-
control = 1;
237-
break;
238-
case StoreCacheControl::L1S_L2UC_L3WB:
239-
control = 2;
240-
break;
241197
case StoreCacheControl::L1S_L2WB_L3UC:
242-
control = 1;
243-
break;
244-
case StoreCacheControl::L1S_L2WB_L3WB:
245-
control = 2;
246-
break;
247198
case StoreCacheControl::L1WB_L2UC_L3UC:
248-
control = 1;
249-
break;
250199
case StoreCacheControl::L1WB_L2WB_L3UC:
251200
control = 1;
252201
break;
202+
case StoreCacheControl::L1UC_L2UC_L3WB:
203+
case StoreCacheControl::L1UC_L2WB_L3WB:
204+
case StoreCacheControl::L1WT_L2UC_L3WB:
205+
case StoreCacheControl::L1WT_L2WB_L3WB:
206+
case StoreCacheControl::L1S_L2UC_L3WB:
207+
case StoreCacheControl::L1S_L2WB_L3WB:
253208
case StoreCacheControl::L1WB_L2UC_L3WB:
254209
control = 2;
255210
break;
@@ -263,13 +218,8 @@ int32_t getL3CacheControl(OpType op) {
263218
template <bool isLoad, typename OpType>
264219
static std::optional<ArrayAttr>
265220
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266-
if constexpr (isLoad) {
267-
if (!op.getCacheControl())
268-
return {};
269-
} else {
270-
if (!op.getCacheControl())
271-
return {};
272-
}
221+
if (!op.getCacheControl())
222+
return {};
273223
constexpr int32_t decorationCacheControlArity{4};
274224
constexpr int32_t loadCacheControlKey{6442};
275225
constexpr int32_t storeCacheControlKey{6443};
@@ -289,13 +239,12 @@ static LLVM::CallOp createDeviceFunctionCall(
289239
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
290240
ArrayRef<Type> argTypes, ArrayRef<Value> args,
291241
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
292-
LLVMFuncAttributeOptions funcAttributeOptions) {
242+
LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
293243
auto moduleOp = rewriter.getBlock()
294244
->getParentOp()
295245
->getParentWithTrait<OpTrait::SymbolTable>();
296246
assert(moduleOp && "Expecting module");
297-
MLIRContext *ctx = rewriter.getContext();
298-
Location loc = UnknownLoc::get(ctx);
247+
Location loc = op->getLoc();
299248

300249
auto funcOpRes =
301250
LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
@@ -384,9 +333,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
384333
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
385334
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
386335
funcAttrs.memEffectsAttr = memAttr;
387-
Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
388-
args, {}, funcAttrs)
389-
->getResult(0);
336+
Value result =
337+
createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
338+
funcAttrs, op.getOperation())
339+
->getResult(0);
390340

391341
if (cOrigTy != cTy)
392342
result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
@@ -419,8 +369,8 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
419369
ConversionPatternRewriter &rewriter) const override {
420370
auto loc = op.getLoc();
421371
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
422-
Value one = rewriter.create<LLVM::ConstantOp>(
423-
loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
372+
Value one =
373+
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
424374
SmallVector<Value> args{op.getPtr(), one};
425375
SmallVector<Type> argTypes;
426376
for (auto arg : args)
@@ -434,7 +384,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
434384

435385
LLVM::CallOp call = createDeviceFunctionCall(
436386
rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
437-
argTypes, args, {}, funcAttr);
387+
argTypes, args, {}, funcAttr, op.getOperation());
438388
if (std::optional<ArrayAttr> optCacheControls =
439389
getCacheControlMetadata<true>(rewriter, op))
440390
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
@@ -473,17 +423,18 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
473423
// CLUSTER and SYSTEM are not supported in OpenCL
474424
llvm_unreachable("unsupported xevm::MemoryScope");
475425
}
476-
Value acqRel = rewriter.create<LLVM::ConstantOp>(
477-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(4));
478-
Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
479-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope));
480-
Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
481-
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace));
426+
Type i32Type = rewriter.getI32Type();
427+
Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
428+
Value memScopeConst =
429+
rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
430+
Value addrSpaceConst =
431+
rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
482432
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
483-
SmallVector<Type> argTypes{3, rewriter.getI32Type()};
433+
SmallVector<Type> argTypes{3, i32Type};
484434
createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
485435
LLVM::LLVMVoidType::get(rewriter.getContext()),
486-
argTypes, args, {}, noUnwindAttrs);
436+
argTypes, args, {}, noUnwindAttrs,
437+
op.getOperation());
487438
rewriter.eraseOp(op);
488439
return success();
489440
}
@@ -512,10 +463,8 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
512463
auto i32Type = rewriter.getI32Type();
513464
Value byteCoord =
514465
rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
515-
Value zero = rewriter.create<LLVM::ConstantOp>(
516-
loc, i32Type, rewriter.getI32IntegerAttr(0));
517-
Value one = rewriter.create<LLVM::ConstantOp>(
518-
loc, i32Type, rewriter.getI32IntegerAttr(1));
466+
Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
467+
Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
519468
byteCoord = rewriter.create<LLVM::InsertElementOp>(
520469
loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
521470
byteCoord = rewriter.create<LLVM::InsertElementOp>(
@@ -589,7 +538,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
589538
}
590539
LLVM::CallOp call = createDeviceFunctionCall(
591540
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
592-
argTypes, args, paramAttrs, funcAttr);
541+
argTypes, args, paramAttrs, funcAttr, op.getOperation());
593542
if (std::optional<ArrayAttr> optCacheControls =
594543
getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
595544
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);

0 commit comments

Comments
 (0)