23
23
24
24
#include " llvm/ADT/TypeSwitch.h"
25
25
26
- #define DEBUG_TYPE " xevm-to-llvm"
27
-
28
26
namespace mlir {
29
27
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30
28
#include " mlir/Conversion/Passes.h.inc"
@@ -70,6 +68,9 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
70
68
default :
71
69
llvm_unreachable (" unhandled integer type" );
72
70
}
71
+ })
72
+ .Default ([](Type) -> std::string {
73
+ llvm_unreachable (" unhandled type for mangling" );
73
74
});
74
75
}
75
76
@@ -165,38 +166,18 @@ int32_t getL3CacheControl(OpType op) {
165
166
if constexpr (isLoad) {
166
167
switch (*op.getCacheControl ()) {
167
168
case LoadCacheControl::L1UC_L2UC_L3UC:
168
- control = 1 ;
169
- break ;
170
- case LoadCacheControl::L1UC_L2UC_L3C:
171
- control = 2 ;
172
- break ;
173
169
case LoadCacheControl::L1UC_L2C_L3UC:
174
- control = 1 ;
175
- break ;
176
- case LoadCacheControl::L1UC_L2C_L3C:
177
- control = 2 ;
178
- break ;
179
170
case LoadCacheControl::L1C_L2UC_L3UC:
180
- control = 1 ;
181
- break ;
182
- case LoadCacheControl::L1C_L2UC_L3C:
183
- control = 2 ;
184
- break ;
185
171
case LoadCacheControl::L1C_L2C_L3UC:
186
- control = 1 ;
187
- break ;
188
- case LoadCacheControl::L1C_L2C_L3C:
189
- control = 2 ;
190
- break ;
191
172
case LoadCacheControl::L1S_L2UC_L3UC:
192
- control = 1 ;
193
- break ;
194
- case LoadCacheControl::L1S_L2UC_L3C:
195
- control = 2 ;
196
- break ;
197
173
case LoadCacheControl::L1S_L2C_L3UC:
198
174
control = 1 ;
199
175
break ;
176
+ case LoadCacheControl::L1UC_L2UC_L3C:
177
+ case LoadCacheControl::L1UC_L2C_L3C:
178
+ case LoadCacheControl::L1C_L2UC_L3C:
179
+ case LoadCacheControl::L1C_L2C_L3C:
180
+ case LoadCacheControl::L1S_L2UC_L3C:
200
181
case LoadCacheControl::L1S_L2C_L3C:
201
182
control = 2 ;
202
183
break ;
@@ -209,47 +190,21 @@ int32_t getL3CacheControl(OpType op) {
209
190
} else {
210
191
switch (*op.getCacheControl ()) {
211
192
case StoreCacheControl::L1UC_L2UC_L3UC:
212
- control = 1 ;
213
- break ;
214
- case StoreCacheControl::L1UC_L2UC_L3WB:
215
- control = 2 ;
216
- break ;
217
193
case StoreCacheControl::L1UC_L2WB_L3UC:
218
- control = 1 ;
219
- break ;
220
- case StoreCacheControl::L1UC_L2WB_L3WB:
221
- control = 2 ;
222
- break ;
223
194
case StoreCacheControl::L1WT_L2UC_L3UC:
224
- control = 1 ;
225
- break ;
226
- case StoreCacheControl::L1WT_L2UC_L3WB:
227
- control = 2 ;
228
- break ;
229
195
case StoreCacheControl::L1WT_L2WB_L3UC:
230
- control = 1 ;
231
- break ;
232
- case StoreCacheControl::L1WT_L2WB_L3WB:
233
- control = 2 ;
234
- break ;
235
196
case StoreCacheControl::L1S_L2UC_L3UC:
236
- control = 1 ;
237
- break ;
238
- case StoreCacheControl::L1S_L2UC_L3WB:
239
- control = 2 ;
240
- break ;
241
197
case StoreCacheControl::L1S_L2WB_L3UC:
242
- control = 1 ;
243
- break ;
244
- case StoreCacheControl::L1S_L2WB_L3WB:
245
- control = 2 ;
246
- break ;
247
198
case StoreCacheControl::L1WB_L2UC_L3UC:
248
- control = 1 ;
249
- break ;
250
199
case StoreCacheControl::L1WB_L2WB_L3UC:
251
200
control = 1 ;
252
201
break ;
202
+ case StoreCacheControl::L1UC_L2UC_L3WB:
203
+ case StoreCacheControl::L1UC_L2WB_L3WB:
204
+ case StoreCacheControl::L1WT_L2UC_L3WB:
205
+ case StoreCacheControl::L1WT_L2WB_L3WB:
206
+ case StoreCacheControl::L1S_L2UC_L3WB:
207
+ case StoreCacheControl::L1S_L2WB_L3WB:
253
208
case StoreCacheControl::L1WB_L2UC_L3WB:
254
209
control = 2 ;
255
210
break ;
@@ -263,13 +218,8 @@ int32_t getL3CacheControl(OpType op) {
263
218
template <bool isLoad, typename OpType>
264
219
static std::optional<ArrayAttr>
265
220
getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op) {
266
- if constexpr (isLoad) {
267
- if (!op.getCacheControl ())
268
- return {};
269
- } else {
270
- if (!op.getCacheControl ())
271
- return {};
272
- }
221
+ if (!op.getCacheControl ())
222
+ return {};
273
223
constexpr int32_t decorationCacheControlArity{4 };
274
224
constexpr int32_t loadCacheControlKey{6442 };
275
225
constexpr int32_t storeCacheControlKey{6443 };
@@ -289,13 +239,12 @@ static LLVM::CallOp createDeviceFunctionCall(
289
239
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
290
240
ArrayRef<Type> argTypes, ArrayRef<Value> args,
291
241
mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
292
- LLVMFuncAttributeOptions funcAttributeOptions) {
242
+ LLVMFuncAttributeOptions funcAttributeOptions, Operation *op ) {
293
243
auto moduleOp = rewriter.getBlock ()
294
244
->getParentOp ()
295
245
->getParentWithTrait <OpTrait::SymbolTable>();
296
246
assert (moduleOp && " Expecting module" );
297
- MLIRContext *ctx = rewriter.getContext ();
298
- Location loc = UnknownLoc::get (ctx);
247
+ Location loc = op->getLoc ();
299
248
300
249
auto funcOpRes =
301
250
LLVM::lookupOrCreateFn (rewriter, moduleOp, funcName, argTypes, retType);
@@ -384,9 +333,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
384
333
/* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
385
334
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
386
335
funcAttrs.memEffectsAttr = memAttr;
387
- Value result = createDeviceFunctionCall (rewriter, fnName, cTy, argTypes,
388
- args, {}, funcAttrs)
389
- ->getResult (0 );
336
+ Value result =
337
+ createDeviceFunctionCall (rewriter, fnName, cTy, argTypes, args, {},
338
+ funcAttrs, op.getOperation ())
339
+ ->getResult (0 );
390
340
391
341
if (cOrigTy != cTy)
392
342
result = rewriter.create <LLVM::BitcastOp>(loc, cOrigTy, result);
@@ -419,8 +369,8 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
419
369
ConversionPatternRewriter &rewriter) const override {
420
370
auto loc = op.getLoc ();
421
371
const std::string fnName{" _Z8prefetchPU3AS1Kcm" };
422
- Value one = rewriter. create <LLVM::ConstantOp>(
423
- loc, rewriter.getI64Type (), rewriter. getI64IntegerAttr ( 1 ) );
372
+ Value one =
373
+ rewriter. create <LLVM::ConstantOp>( loc, rewriter.getI64Type (), 1 );
424
374
SmallVector<Value> args{op.getPtr (), one};
425
375
SmallVector<Type> argTypes;
426
376
for (auto arg : args)
@@ -434,7 +384,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
434
384
435
385
LLVM::CallOp call = createDeviceFunctionCall (
436
386
rewriter, fnName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
437
- argTypes, args, {}, funcAttr);
387
+ argTypes, args, {}, funcAttr, op. getOperation () );
438
388
if (std::optional<ArrayAttr> optCacheControls =
439
389
getCacheControlMetadata<true >(rewriter, op))
440
390
call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
@@ -473,17 +423,18 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
473
423
// CLUSTER and SYSTEM are not supported in OpenCL
474
424
llvm_unreachable (" unsupported xevm::MemoryScope" );
475
425
}
476
- Value acqRel = rewriter.create <LLVM::ConstantOp>(
477
- loc, rewriter.getI32Type (), rewriter. getI32IntegerAttr ( 4 ) );
478
- Value memScopeConst = rewriter. create <LLVM::ConstantOp>(
479
- loc, rewriter.getI32Type (), rewriter. getI32IntegerAttr ( memScope) );
480
- Value addrSpaceConst = rewriter. create <LLVM::ConstantOp>(
481
- loc, rewriter.getI32Type (), rewriter. getI32IntegerAttr ( addrSpace) );
426
+ Type i32Type = rewriter.getI32Type ();
427
+ Value acqRel = rewriter.create <LLVM::ConstantOp>(loc, i32Type, 4 );
428
+ Value memScopeConst =
429
+ rewriter.create <LLVM::ConstantOp>(loc, i32Type, memScope);
430
+ Value addrSpaceConst =
431
+ rewriter.create <LLVM::ConstantOp>(loc, i32Type, addrSpace);
482
432
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
483
- SmallVector<Type> argTypes{3 , rewriter. getI32Type () };
433
+ SmallVector<Type> argTypes{3 , i32Type };
484
434
createDeviceFunctionCall (rewriter, mangle (fnName, argTypes),
485
435
LLVM::LLVMVoidType::get (rewriter.getContext ()),
486
- argTypes, args, {}, noUnwindAttrs);
436
+ argTypes, args, {}, noUnwindAttrs,
437
+ op.getOperation ());
487
438
rewriter.eraseOp (op);
488
439
return success ();
489
440
}
@@ -512,10 +463,8 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
512
463
auto i32Type = rewriter.getI32Type ();
513
464
Value byteCoord =
514
465
rewriter.create <LLVM::UndefOp>(loc, VectorType::get (2 , i32Type));
515
- Value zero = rewriter.create <LLVM::ConstantOp>(
516
- loc, i32Type, rewriter.getI32IntegerAttr (0 ));
517
- Value one = rewriter.create <LLVM::ConstantOp>(
518
- loc, i32Type, rewriter.getI32IntegerAttr (1 ));
466
+ Value zero = rewriter.create <LLVM::ConstantOp>(loc, i32Type, 0 );
467
+ Value one = rewriter.create <LLVM::ConstantOp>(loc, i32Type, 1 );
519
468
byteCoord = rewriter.create <LLVM::InsertElementOp>(
520
469
loc, VectorType::get (2 , i32Type), byteCoord, op.getX (), zero);
521
470
byteCoord = rewriter.create <LLVM::InsertElementOp>(
@@ -589,7 +538,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
589
538
}
590
539
LLVM::CallOp call = createDeviceFunctionCall (
591
540
rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
592
- argTypes, args, paramAttrs, funcAttr);
541
+ argTypes, args, paramAttrs, funcAttr, op. getOperation () );
593
542
if (std::optional<ArrayAttr> optCacheControls =
594
543
getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
595
544
call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
0 commit comments