Skip to content

Commit fc8b2bf

Browse files
[MLIR][LLVM] Import dereferenceable metadata from LLVM IR (llvm#130974)
Add support for importing `dereferenceable` and `dereferenceable_or_null` metadata into LLVM dialect. Add a new attribute which models these two metadata nodes and a new OpInterface.
1 parent bddf24d commit fc8b2bf

File tree

15 files changed

+261
-3
lines changed

15 files changed

+261
-3
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,4 +1267,28 @@ def WorkgroupAttributionAttr
12671267
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
12681268
}
12691269

1270+
//===----------------------------------------------------------------------===//
1271+
// DereferenceableAttr
1272+
//===----------------------------------------------------------------------===//
1273+
1274+
def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
1275+
let summary = "LLVM dereferenceable attribute";
1276+
let description = [{
1277+
Defines `dereferenceable` or `dereferenceable_or_null` metadata that can
1278+
be set via the `DereferenceableOpInterface` on an `inttoptr` operation or
1279+
on a `load` operation which loads a pointer. The attribute is used to
1280+
denote that the result of these operations is dereferenceable up to a
1281+
certain number of bytes, represented by `$bytes`. The optional `$mayBeNull`
1282+
parameter is set to true if the attribute defines `dereferenceable_or_null`
1283+
metadata.
1284+
1285+
See the following links for more details:
1286+
https://llvm.org/docs/LangRef.html#dereferenceable-metadata
1287+
https://llvm.org/docs/LangRef.html#dereferenceable-or-null-metadata
1288+
}];
1289+
let parameters = (ins "uint64_t":$bytes,
1290+
DefaultValuedParameter<"bool", "false">:$mayBeNull);
1291+
let assemblyFormat = "`<` struct(params) `>`";
1292+
}
1293+
12701294
#endif // LLVMIR_ATTRDEFS

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ LogicalResult verifyAccessGroupOpInterface(Operation *op);
2727
/// the alias analysis interface.
2828
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);
2929

30+
/// Verifies that the operation implementing the dereferenceable interface has
31+
/// exactly one result of LLVM pointer type.
32+
LogicalResult verifyDereferenceableOpInterface(Operation *op);
33+
3034
} // namespace detail
3135
} // namespace LLVM
3236
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,43 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
330330
];
331331
}
332332

333+
def DereferenceableOpInterface : OpInterface<"DereferenceableOpInterface"> {
334+
let description = [{
335+
An interface for memory operations that can carry dereferenceable metadata.
336+
It provides setters and getters for the operation's dereferenceable
337+
attributes. The default implementations of the interface methods expect
338+
the operation to have an attribute of type DereferenceableAttr.
339+
}];
340+
341+
let cppNamespace = "::mlir::LLVM";
342+
let verify = [{ return detail::verifyDereferenceableOpInterface($_op); }];
343+
344+
let methods = [
345+
InterfaceMethod<
346+
/*desc=*/ "Returns the dereferenceable attribute or nullptr",
347+
/*returnType=*/ "::mlir::LLVM::DereferenceableAttr",
348+
/*methodName=*/ "getDereferenceableOrNull",
349+
/*args=*/ (ins),
350+
/*methodBody=*/ [{}],
351+
/*defaultImpl=*/ [{
352+
auto op = cast<ConcreteOp>(this->getOperation());
353+
return op.getDereferenceableAttr();
354+
}]
355+
>,
356+
InterfaceMethod<
357+
/*desc=*/ "Sets the dereferenceable attribute",
358+
/*returnType=*/ "void",
359+
/*methodName=*/ "setDereferenceable",
360+
/*args=*/ (ins "::mlir::LLVM::DereferenceableAttr":$attr),
361+
/*methodBody=*/ [{}],
362+
/*defaultImpl=*/ [{
363+
auto op = cast<ConcreteOp>(this->getOperation());
364+
op.setDereferenceableAttr(attr);
365+
}]
366+
>
367+
];
368+
}
369+
333370
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
334371
let description = [{
335372
An interface for operations receiving an exception behavior attribute

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
364364
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
365365
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
366366
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
367-
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
367+
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
368+
DeclareOpInterfaceMethods<DereferenceableOpInterface>]> {
368369
dag args = (ins LLVM_AnyPointer:$addr,
369370
OptionalAttr<I64Attr>:$alignment,
370371
UnitAttr:$volatile_,
@@ -373,7 +374,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
373374
UnitAttr:$invariantGroup,
374375
DefaultValuedAttr<
375376
AtomicOrdering, "AtomicOrdering::not_atomic">:$ordering,
376-
OptionalAttr<StrAttr>:$syncscope);
377+
OptionalAttr<StrAttr>:$syncscope,
378+
OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
377379
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
378380
let arguments = !con(args, aliasAttrs);
379381
let results = (outs LLVM_LoadableType:$res);
@@ -407,6 +409,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
407409
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
408410
(`invariant` $invariant^)?
409411
(`invariant_group` $invariantGroup^)?
412+
(`dereferenceable` `` $dereferenceable^)?
410413
attr-dict `:` qualified(type($addr)) `->` type($res)
411414
}];
412415
string llvmBuilder = [{
@@ -416,6 +419,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
416419
llvm::MDNode *metadata = llvm::MDNode::get(inst->getContext(), std::nullopt);
417420
inst->setMetadata(llvm::LLVMContext::MD_invariant_load, metadata);
418421
}
422+
if ($dereferenceable)
423+
moduleTranslation.setDereferenceableMetadata(op, inst);
419424
}] # setOrderingCode
420425
# setSyncScopeCode
421426
# setAlignmentCode
@@ -571,6 +576,29 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
571576
}];
572577
}
573578

579+
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
580+
Type resultType, list<Trait> traits = []> :
581+
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
582+
let arguments = (ins type:$arg, OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
583+
let results = (outs resultType:$res);
584+
let builders = [LLVM_OneResultOpBuilder];
585+
let assemblyFormat = "$arg (`dereferenceable` `` $dereferenceable^)? attr-dict `:` type($arg) `to` type($res)";
586+
string llvmInstName = instName;
587+
string llvmBuilder = [{
588+
auto *val = builder.Create}] # instName # [{($arg, $_resultType);
589+
$res = val;
590+
if ($dereferenceable) {
591+
llvm::Instruction *inst = dyn_cast<llvm::Instruction>(val);
592+
moduleTranslation.setDereferenceableMetadata(op, inst);
593+
}
594+
}];
595+
string mlirBuilder = [{
596+
auto op = $_builder.create<$_qualCppClassName>(
597+
$_location, $_resultType, $arg);
598+
$res = op;
599+
}];
600+
}
601+
574602
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
575603
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
576604
let hasFolder = 1;
@@ -583,7 +611,7 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
583611
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
584612
let hasFolder = 1;
585613
}
586-
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
614+
def LLVM_IntToPtrOp : LLVM_DereferenceableCastOp<"inttoptr", "IntToPtr",
587615
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
588616
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
589617
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ class ModuleImport {
248248
LoopAnnotationAttr translateLoopAnnotationAttr(const llvm::MDNode *node,
249249
Location loc) const;
250250

251+
/// Returns the dereferenceable attribute that corresponds to the given LLVM
252+
/// dereferenceable or dereferenceable_or_null metadata `node`. `kindID`
253+
/// specifies the kind of the metadata node (dereferenceable or
254+
/// dereferenceable_or_null).
255+
FailureOr<DereferenceableAttr>
256+
translateDereferenceableAttr(const llvm::MDNode *node, unsigned kindID);
257+
251258
/// Returns the alias scope attributes that map to the alias scope nodes
252259
/// starting from the metadata `node`. Returns failure, if any of the
253260
/// attributes cannot be found.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class ModuleTranslation {
161161
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
162162
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
163163

164+
/// Sets LLVM dereferenceable metadata for operations that have
165+
/// dereferenceable attributes.
166+
void setDereferenceableMetadata(DereferenceableOpInterface op,
167+
llvm::Instruction *inst);
168+
164169
/// Sets LLVM profiling metadata for operations that have branch weights.
165170
void setBranchWeightsMetadata(BranchWeightOpInterface op);
166171

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
940940
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
941941
isNonTemporal, isInvariant, isInvariantGroup, ordering,
942942
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
943+
/*dereferenceable=*/nullptr,
943944
/*access_groups=*/nullptr,
944945
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
945946
/*tbaa=*/nullptr);

mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
6262
return isArrayOf<TBAATagAttr>(op, tags);
6363
}
6464

65+
//===----------------------------------------------------------------------===//
66+
// DereferenceableOpInterface
67+
//===----------------------------------------------------------------------===//
68+
69+
LogicalResult
70+
mlir::LLVM::detail::verifyDereferenceableOpInterface(Operation *op) {
71+
auto iface = cast<DereferenceableOpInterface>(op);
72+
73+
if (auto derefAttr = iface.getDereferenceableOrNull())
74+
if (op->getNumResults() != 1 ||
75+
!mlir::isa<LLVMPointerType>(op->getResult(0).getType()))
76+
return op->emitOpError(
77+
"expected op to return a single LLVM pointer type");
78+
79+
return success();
80+
}
81+
6582
SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
6683
return {getPtr()};
6784
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
9090
llvm::LLVMContext::MD_loop,
9191
llvm::LLVMContext::MD_noalias,
9292
llvm::LLVMContext::MD_alias_scope,
93+
llvm::LLVMContext::MD_dereferenceable,
94+
llvm::LLVMContext::MD_dereferenceable_or_null,
9395
context.getMDKindID(vecTypeHintMDName),
9496
context.getMDKindID(workGroupSizeHintMDName),
9597
context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -188,6 +190,25 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
188190
return success();
189191
}
190192

193+
/// Converts the given dereferenceable metadata node to a dereferenceable
194+
/// attribute, and attaches it to the imported operation if the translation
195+
/// succeeds. Returns failure if the LLVM IR metadata node is ill-formed.
196+
static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
197+
unsigned kindID, Operation *op,
198+
LLVM::ModuleImport &moduleImport) {
199+
auto dereferenceable =
200+
moduleImport.translateDereferenceableAttr(node, kindID);
201+
if (failed(dereferenceable))
202+
return failure();
203+
204+
auto iface = dyn_cast<DereferenceableOpInterface>(op);
205+
if (!iface)
206+
return failure();
207+
208+
iface.setDereferenceable(*dereferenceable);
209+
return success();
210+
}
211+
191212
/// Converts the given loop metadata node to an MLIR loop annotation attribute
192213
/// and attaches it to the imported operation if the translation succeeds.
193214
/// Returns failure otherwise.
@@ -401,6 +422,13 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
401422
return setAliasScopesAttr(node, op, moduleImport);
402423
if (kind == llvm::LLVMContext::MD_noalias)
403424
return setNoaliasScopesAttr(node, op, moduleImport);
425+
if (kind == llvm::LLVMContext::MD_dereferenceable)
426+
return setDereferenceableAttr(node, llvm::LLVMContext::MD_dereferenceable,
427+
op, moduleImport);
428+
if (kind == llvm::LLVMContext::MD_dereferenceable_or_null)
429+
return setDereferenceableAttr(
430+
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
431+
moduleImport);
404432

405433
llvm::LLVMContext &context = node->getContext();
406434
if (kind == context.getMDKindID(vecTypeHintMDName))

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,6 +2527,31 @@ ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
25272527
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
25282528
}
25292529

2530+
FailureOr<DereferenceableAttr>
2531+
ModuleImport::translateDereferenceableAttr(const llvm::MDNode *node,
2532+
unsigned kindID) {
2533+
Location loc = mlirModule.getLoc();
2534+
2535+
// The only operand should be a constant integer representing the number of
2536+
// dereferenceable bytes.
2537+
if (node->getNumOperands() != 1)
2538+
return emitError(loc) << "dereferenceable metadata must have one operand: "
2539+
<< diagMD(node, llvmModule.get());
2540+
2541+
auto *numBytesMD = dyn_cast<llvm::ConstantAsMetadata>(node->getOperand(0));
2542+
auto *numBytesCst = dyn_cast<llvm::ConstantInt>(numBytesMD->getValue());
2543+
if (!numBytesCst || !numBytesCst->getValue().isNonNegative())
2544+
return emitError(loc) << "dereferenceable metadata operand must be a "
2545+
"non-negative constant integer: "
2546+
<< diagMD(node, llvmModule.get());
2547+
2548+
bool mayBeNull = kindID == llvm::LLVMContext::MD_dereferenceable_or_null;
2549+
auto derefAttr = builder.getAttr<DereferenceableAttr>(
2550+
numBytesCst->getZExtValue(), mayBeNull);
2551+
2552+
return derefAttr;
2553+
}
2554+
25302555
OwningOpRef<ModuleOp>
25312556
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
25322557
MLIRContext *context, bool emitExpensiveWarnings,

0 commit comments

Comments
 (0)