diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 63e007cdc335c..e355bb8f5ddae 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -223,6 +223,9 @@ Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, /// function confirms that the Operation has the desired properties. bool satisfiesLLVMModule(Operation *op); +/// Lookup parent Module satisfying LLVM conditions on the Module Operation. +Operation *parentLLVMModule(Operation *op); + /// Convert an array of integer attributes to a vector of integers that can be /// used as indices in LLVM operations. template diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f4c1640098320..fd3a2be29a242 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1285,6 +1285,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", /// Return the llvm.mlir.alias operation that defined the value referenced /// here. AliasOp getAlias(SymbolTableCollection &symbolTable); + + /// Return the llvm.mlir.ifunc operation that defined the value referenced + /// here. + IFuncOp getIFunc(SymbolTableCollection &symbolTable); }]; let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))"; @@ -1601,6 +1605,67 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias", let hasRegionVerifier = 1; } +def LLVM_IFuncOp : LLVM_Op<"mlir.ifunc", + [IsolatedFromAbove, Symbol, DeclareOpInterfaceMethods]> { + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttr:$i_func_type, + FlatSymbolRefAttr:$resolver, + TypeAttr:$resolver_type, + Linkage:$linkage, + UnitAttr:$dso_local, + DefaultValuedAttr, "0">:$address_space, + DefaultValuedAttr:$unnamed_addr, + DefaultValuedAttr:$visibility_ + ); + let summary = "LLVM dialect ifunc"; + let description = [{ + `llvm.mlir.ifunc` is a top level operation that defines a global ifunc. + It defines a new symbol and takes a symbol refering to a resolver function. + IFuncs can be called as regular functions. The function type is the same + as the IFuncType. The symbol is resolved at runtime by calling a resolver + function. + + Examples: + + ```mlir + // IFuncs have @-identifier and use a resolver function. + llvm.mlir.ifunc external @foo: !llvm.func, !llvm.ptr @resolver + + llvm.func @foo_1(i64) -> f32 + llvm.func @foo_2(i64) -> f32 + + llvm.func @resolve_foo() -> !llvm.ptr attributes { + %0 = llvm.mlir.addressof @foo_2 : !llvm.ptr + %1 = llvm.mlir.addressof @foo_1 : !llvm.ptr + + // ... Logic selecting from foo_{1, 2} + + // Return function pointer to the selected function + llvm.return %7 : !llvm.ptr + } + + llvm.func @use_foo() { + // IFuncs are called as regular functions + %res = llvm.call @foo(%value) : i64 -> f32 + } + ``` + }]; + + let builders = [ + OpBuilder<(ins "StringRef":$name, "Type":$i_func_type, + "StringRef":$resolver, "Type":$resolver_type, + "Linkage":$linkage, "LLVM::Visibility":$visibility)> + ]; + + let assemblyFormat = [{ + custom($linkage) ($visibility_^)? ($unnamed_addr^)? + $sym_name `:` $i_func_type `,` $resolver_type $resolver attr-dict + }]; + let hasVerifier = 1; +} + + def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent", [Pure, ConstantLike, DeclareOpInterfaceMethods]> { let arguments = (ins FlatSymbolRefAttr:$function_name); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 9902c6bb15caf..886c6df009d39 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -71,6 +71,9 @@ class ModuleImport { /// Converts all aliases of the LLVM module to MLIR variables. LogicalResult convertAliases(); + /// Converts all ifuncs of the LLVM module to MLIR variables. + LogicalResult convertIFuncs(); + /// Converts the data layout of the LLVM module to an MLIR data layout /// specification. LogicalResult convertDataLayout(); @@ -320,6 +323,8 @@ class ModuleImport { /// Converts an LLVM global alias variable into an MLIR LLVM dialect alias /// operation if a conversion exists. Otherwise, returns failure. LogicalResult convertAlias(llvm::GlobalAlias *alias); + // Converts an LLVM global ifunc into an MLIR LLVM dialect ifunc operation + LogicalResult convertIFunc(llvm::GlobalIFunc *ifunc); /// Returns personality of `func` as a FlatSymbolRefAttr. FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func); /// Imports `bb` into `block`, which must be initially empty. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 79e8bb6add0da..515eac41c7193 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -223,6 +223,12 @@ class ModuleTranslation { return aliasesMapping.lookup(op); } + /// Finds an LLVM IR global value that corresponds to the given MLIR operation + /// defining an IFunc. + llvm::GlobalValue *lookupIFunc(Operation *op) { + return ifuncMapping.lookup(op); + } + /// Returns the OpenMP IR builder associated with the LLVM IR module being /// constructed. llvm::OpenMPIRBuilder *getOpenMPBuilder(); @@ -308,6 +314,7 @@ class ModuleTranslation { bool recordInsertions = false); LogicalResult convertFunctionSignatures(); LogicalResult convertFunctions(); + LogicalResult convertIFuncs(); LogicalResult convertComdats(); LogicalResult convertUnresolvedBlockAddress(); @@ -369,6 +376,10 @@ class ModuleTranslation { /// aliases. DenseMap aliasesMapping; + /// Mappings between llvm.mlir.ifunc definitions and corresponding global + /// ifuncs. + DenseMap ifuncMapping; + /// A stateful object used to translate types. TypeToLLVMIRTranslator typeTranslator; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 6dcd94e6eea17..b3bf0c4dfc322 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -139,6 +139,17 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, return static_cast(index); } +static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) { + p << stringifyLinkage(val.getLinkage()); +} + +static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) { + val = LinkageAttr::get( + p.getContext(), + parseOptionalLLVMKeyword(p, LLVM::Linkage::External)); + return success(); +} + //===----------------------------------------------------------------------===// // Operand bundle helpers. //===----------------------------------------------------------------------===// @@ -1175,14 +1186,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return emitOpError() << "'" << calleeName.getValue() << "' does not reference a symbol in the current scope"; - auto fn = dyn_cast(callee); - if (!fn) - return emitOpError() << "'" << calleeName.getValue() - << "' does not reference a valid LLVM function"; - - if (failed(verifyCallOpDebugInfo(*this, fn))) - return failure(); - fnType = fn.getFunctionType(); + if (auto fn = dyn_cast(callee)) { + if (failed(verifyCallOpDebugInfo(*this, fn))) + return failure(); + fnType = fn.getFunctionType(); + } else if (auto ifunc = dyn_cast(callee)) { + fnType = ifunc.getIFuncType(); + } else { + return emitOpError() + << "'" << calleeName.getValue() + << "' does not reference a valid LLVM function or IFunc"; + } } LLVMFunctionType funcType = llvm::dyn_cast(fnType); @@ -2038,14 +2052,6 @@ LogicalResult ReturnOp::verify() { // LLVM::AddressOfOp. //===----------------------------------------------------------------------===// -static Operation *parentLLVMModule(Operation *op) { - Operation *module = op->getParentOp(); - while (module && !satisfiesLLVMModule(module)) - module = module->getParentOp(); - assert(module && "unexpected operation outside of a module"); - return module; -} - GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) { return dyn_cast_or_null( symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); @@ -2061,6 +2067,11 @@ AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) { symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); } +IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) { + return dyn_cast_or_null( + symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); +} + LogicalResult AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) { Operation *symbol = @@ -2069,10 +2080,11 @@ AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto global = dyn_cast_or_null(symbol); auto function = dyn_cast_or_null(symbol); auto alias = dyn_cast_or_null(symbol); + auto ifunc = dyn_cast_or_null(symbol); - if (!global && !function && !alias) + if (!global && !function && !alias && !ifunc) return emitOpError("must reference a global defined by 'llvm.mlir.global', " - "'llvm.mlir.alias' or 'llvm.func'"); + "'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'"); LLVMPointerType type = getType(); if ((global && global.getAddrSpace() != type.getAddressSpace()) || @@ -2682,6 +2694,59 @@ unsigned AliasOp::getAddrSpace() { return ptrTy.getAddressSpace(); } +//===----------------------------------------------------------------------===// +// IFuncOp +//===----------------------------------------------------------------------===// + +void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, + Type iFuncType, StringRef resolverName, Type resolverType, + Linkage linkage, LLVM::Visibility visibility) { + return build(builder, result, name, iFuncType, resolverName, resolverType, + linkage, /*dso_local=*/false, /*address_space=*/0, + UnnamedAddr::None, visibility); +} + +LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + Operation *symbol = + symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr()); + auto resolver = dyn_cast(symbol); + if (!resolver) { + // FIXME: Strip aliases to find the called function + if (isa(symbol)) + return success(); + return emitOpError("must have a function resolver"); + } + // This matches LLVM IR verification logic, see from llvm/lib/IR/Verifier.cpp + Linkage linkage = resolver.getLinkage(); + if (resolver.isExternal() || linkage == Linkage::AvailableExternally) + return emitOpError("resolver must be a definition"); + if (!isa(resolver.getFunctionType().getReturnType())) + return emitOpError("resolver must return a pointer"); + auto resolverPtr = dyn_cast(getResolverType()); + if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace()) + return emitOpError("resolver has incorrect type"); + return success(); +} + +LogicalResult IFuncOp::verify() { + switch (getLinkage()) { + case Linkage::External: + case Linkage::Internal: + case Linkage::Private: + case Linkage::Weak: + case Linkage::WeakODR: + case Linkage::Linkonce: + case Linkage::LinkonceODR: + break; + default: + return emitOpError() << "'" << stringifyLinkage(getLinkage()) + << "' linkage not supported in ifuncs, available " + "options: private, internal, linkonce, weak, " + "linkonce_odr, weak_odr, or external linkage"; + } + return success(); +} + //===----------------------------------------------------------------------===// // ShuffleVectorOp //===----------------------------------------------------------------------===// @@ -4329,3 +4394,11 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { return op->hasTrait() && op->hasTrait(); } + +Operation *mlir::LLVM::parentLLVMModule(Operation *op) { + Operation *module = op->getParentOp(); + while (module && !satisfiesLLVMModule(module)) + module = module->getParentOp(); + assert(module && "unexpected operation outside of a module"); + return module; +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 70029d7e15a90..ff34a0825215c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -422,9 +422,18 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, ArrayRef operandsRef(operands); llvm::CallInst *call; if (auto attr = callOp.getCalleeAttr()) { - call = - builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()), - operandsRef, opBundles); + if (llvm::Function *function = + moduleTranslation.lookupFunction(attr.getValue())) { + call = builder.CreateCall(function, operandsRef, opBundles); + } else { + Operation *moduleOp = parentLLVMModule(&opInst); + Operation *ifuncOp = + moduleTranslation.symbolTable().lookupSymbolIn(moduleOp, attr); + llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp); + llvm::FunctionType *calleeType = llvm::cast( + moduleTranslation.convertType(callOp.getCalleeFunctionType())); + call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles); + } } else { llvm::FunctionType *calleeType = llvm::cast( moduleTranslation.convertType(callOp.getCalleeFunctionType())); @@ -648,18 +657,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::LLVMFuncOp function = addressOfOp.getFunction(moduleTranslation.symbolTable()); LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable()); + LLVM::IFuncOp ifunc = addressOfOp.getIFunc(moduleTranslation.symbolTable()); // The verifier should not have allowed this. - assert((global || function || alias) && - "referencing an undefined global, function, or alias"); + assert((global || function || alias || ifunc) && + "referencing an undefined global, function, alias, or ifunc"); llvm::Value *llvmValue = nullptr; if (global) llvmValue = moduleTranslation.lookupGlobal(global); else if (alias) llvmValue = moduleTranslation.lookupAlias(alias); - else + else if (function) llvmValue = moduleTranslation.lookupFunction(function.getName()); + else + llvmValue = moduleTranslation.lookupIFunc(ifunc); moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue); return success(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index bfda223fe0f5f..c807985756539 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1031,6 +1031,16 @@ LogicalResult ModuleImport::convertAliases() { return success(); } +LogicalResult ModuleImport::convertIFuncs() { + for (llvm::GlobalIFunc &ifunc : llvmModule->ifuncs()) { + if (failed(convertIFunc(&ifunc))) { + return emitError(UnknownLoc::get(context)) + << "unhandled global ifunc: " << diag(ifunc); + } + } + return success(); +} + LogicalResult ModuleImport::convertDataLayout() { Location loc = mlirModule.getLoc(); DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout()); @@ -1369,6 +1379,21 @@ LogicalResult ModuleImport::convertAlias(llvm::GlobalAlias *alias) { return success(); } +LogicalResult ModuleImport::convertIFunc(llvm::GlobalIFunc *ifunc) { + OpBuilder::InsertionGuard guard = setGlobalInsertionPoint(); + + Type type = convertType(ifunc->getValueType()); + llvm::Constant *resolver = ifunc->getResolver(); + Type resolverType = convertType(resolver->getType()); + builder.create(mlirModule.getLoc(), ifunc->getName(), type, + resolver->getName(), resolverType, + convertLinkageFromLLVM(ifunc->getLinkage()), + ifunc->isDSOLocal(), ifunc->getAddressSpace(), + convertUnnamedAddrFromLLVM(ifunc->getUnnamedAddr()), + convertVisibilityFromLLVM(ifunc->getVisibility())); + return success(); +} + LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { // Insert the global after the last one or at the start of the module. OpBuilder::InsertionGuard guard = setGlobalInsertionPoint(); @@ -1973,8 +1998,9 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst, // treated as indirect calls to constant operands that need to be converted. // Skip the callee operand if it's inline assembly, as it's handled separately // in InlineAsmOp. - if (!isa(callInst->getCalledOperand()) && !isInlineAsm) { - FailureOr called = convertValue(callInst->getCalledOperand()); + llvm::Value *calleeOperand = callInst->getCalledOperand(); + if (!isa(calleeOperand) && !isInlineAsm) { + FailureOr called = convertValue(calleeOperand); if (failed(called)) return failure(); operands.push_back(*called); @@ -2035,12 +2061,20 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst, if (failed(callType)) return failure(); auto *callee = dyn_cast(calledOperand); + + llvm::FunctionType *origCalleeType = nullptr; + if (callee) { + origCalleeType = callee->getFunctionType(); + } else if (auto *ifunc = dyn_cast(calledOperand)) { + origCalleeType = cast(ifunc->getValueType()); + } + // For indirect calls, return the type of the call itself. - if (!callee) + if (!origCalleeType) return callType; FailureOr calleeType = - castOrFailure(convertType(callee->getFunctionType())); + castOrFailure(convertType(origCalleeType)); if (failed(calleeType)) return failure(); @@ -2059,8 +2093,8 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst, FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) { llvm::Value *calledOperand = callInst->getCalledOperand(); - if (auto *callee = dyn_cast(calledOperand)) - return SymbolRefAttr::get(context, callee->getName()); + if (isa(calledOperand)) + return SymbolRefAttr::get(context, calledOperand->getName()); return {}; } @@ -3162,6 +3196,8 @@ OwningOpRef mlir::translateLLVMIRToModule( return {}; if (failed(moduleImport.convertAliases())) return {}; + if (failed(moduleImport.convertIFuncs())) + return {}; moduleImport.convertTargetTriple(); return module; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 8908703cc1368..165e06b021fd3 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -791,6 +791,8 @@ void ModuleTranslation::forgetMapping(Region ®ion) { globalsMapping.erase(&op); if (isa(op)) aliasesMapping.erase(&op); + if (isa(op)) + ifuncMapping.erase(&op); if (isa(op)) callMapping.erase(&op); llvm::append_range( @@ -1868,6 +1870,33 @@ LogicalResult ModuleTranslation::convertFunctions() { return success(); } +LogicalResult ModuleTranslation::convertIFuncs() { + for (auto op : getModuleBody(mlirModule).getOps()) { + llvm::Type *type = convertType(op.getIFuncType()); + llvm::GlobalValue::LinkageTypes linkage = + convertLinkageToLLVM(op.getLinkage()); + llvm::Constant *resolver; + if (auto *resolverFn = lookupFunction(op.getResolver())) { + resolver = dyn_cast(resolverFn); + } else { + Operation *aliasTrg = symbolTable().lookupSymbolIn(parentLLVMModule(op), + op.getResolverAttr()); + resolver = cast(lookupAlias(aliasTrg)); + } + + auto *ifunc = + llvm::GlobalIFunc::create(type, op.getAddressSpace(), linkage, + op.getSymName(), resolver, llvmModule.get()); + addRuntimePreemptionSpecifier(op.getDsoLocal(), ifunc); + ifunc->setUnnamedAddr(convertUnnamedAddrToLLVM(op.getUnnamedAddr())); + ifunc->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); + + ifuncMapping.try_emplace(op, ifunc); + } + + return success(); +} + LogicalResult ModuleTranslation::convertComdats() { for (auto comdatOp : getModuleBody(mlirModule).getOps()) { for (auto selectorOp : comdatOp.getOps()) { @@ -2284,6 +2313,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, return nullptr; if (failed(translator.convertGlobalsAndAliases())) return nullptr; + if (failed(translator.convertIFuncs())) + return nullptr; if (failed(translator.createTBAAMetadata())) return nullptr; if (failed(translator.createIdentMetadata())) @@ -2296,7 +2327,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, // Convert other top-level operations if possible. for (Operation &o : getModuleBody(module).getOperations()) { if (!isa(&o) && + LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp, + LLVM::IFuncOp>(&o) && !o.hasTrait() && failed(translator.convertOperation(o, llvmBuilder))) { return nullptr; diff --git a/mlir/test/Dialect/LLVMIR/ifunc.mlir b/mlir/test/Dialect/LLVMIR/ifunc.mlir new file mode 100644 index 0000000000000..5df81e9553eee --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/ifunc.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt %s -split-input-file --verify-roundtrip | FileCheck %s + +// CHECK: llvm.mlir.ifunc external @ifunc : !llvm.func, !llvm.ptr @resolver +llvm.mlir.ifunc @ifunc : !llvm.func, !llvm.ptr @resolver +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: llvm.mlir.ifunc linkonce_odr hidden @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.mlir.ifunc linkonce_odr hidden @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: llvm.mlir.ifunc private @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.mlir.ifunc private @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: llvm.mlir.ifunc weak @ifunc : !llvm.func, !llvm.ptr @resolver + +llvm.mlir.ifunc weak @ifunc : !llvm.func, !llvm.ptr @resolver +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index bd1106e304c60..dc39145f446ff 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1931,3 +1931,20 @@ llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei llvm.return %loaded_a : vector<8xi16> } +// ----- + +llvm.func external @resolve_foo() -> !llvm.ptr attributes {dso_local} +// expected-error@+1 {{'llvm.mlir.ifunc' op resolver must be a definition}} +llvm.mlir.ifunc external @foo : !llvm.func, !llvm.ptr @resolve_foo {dso_local} + +// ----- + +llvm.mlir.global external @resolve_foo() : !llvm.ptr +// expected-error@+1 {{'llvm.mlir.ifunc' op must have a function resolver}} +llvm.mlir.ifunc external @foo : !llvm.func, !llvm.ptr @resolve_foo {dso_local} + +// ----- + +llvm.mlir.global external @resolve_foo() : !llvm.ptr +// expected-error@+1 {{'llvm.mlir.ifunc' op 'common' linkage not supported in ifuncs}} +llvm.mlir.ifunc common @foo : !llvm.func, !llvm.ptr @resolve_foo {dso_local} diff --git a/mlir/test/Target/LLVMIR/Import/ifunc.ll b/mlir/test/Target/LLVMIR/Import/ifunc.ll new file mode 100644 index 0000000000000..56cd4b7215353 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/ifunc.ll @@ -0,0 +1,80 @@ +; RUN: mlir-translate --import-llvm %s --split-input-file | FileCheck %s + +; CHECK: llvm.mlir.ifunc external @foo : !llvm.func, !llvm.ptr @resolve_foo {dso_local} +@foo = dso_local ifunc void (ptr, i32), ptr @resolve_foo + +define dso_local void @call_foo(ptr noundef %0, i32 noundef %1) { + %3 = alloca ptr, align 8 + %4 = alloca i32, align 4 + store ptr %0, ptr %3, align 8 + store i32 %1, ptr %4, align 4 + %5 = load ptr, ptr %3, align 8 + %6 = load i32, ptr %4, align 4 +; CHECK: llvm.call @foo + call void @foo(ptr noundef %5, i32 noundef %6) + ret void +} + +define dso_local void @call_indirect_foo(ptr noundef %0, i32 noundef %1) { + %3 = alloca ptr, align 8 + %4 = alloca i32, align 4 + %5 = alloca ptr, align 8 +; CHECK: [[CALLEE:%[0-9]+]] = llvm.mlir.addressof @foo +; CHECK: llvm.store [[CALLEE]], [[STORED:%[0-9]+]] +; CHECK: [[LOADED_CALLEE:%[0-9]+]] = llvm.load [[STORED]] + store ptr %0, ptr %3, align 8 + store i32 %1, ptr %4, align 4 + store ptr @foo, ptr %5, align 8 + %6 = load ptr, ptr %5, align 8 + %7 = load ptr, ptr %3, align 8 + %8 = load i32, ptr %4, align 4 + call void %6(ptr noundef %7, i32 noundef %8) + ret void +} + +define internal ptr @resolve_foo() { + ret ptr @foo_1 +} + +declare void @foo_1(ptr noundef, i32 noundef) + +; // ----- + +define ptr @resolver() { + ret ptr inttoptr (i64 333 to ptr) +} + +@resolver_alias = alias ptr (), ptr @resolver +@resolver_alias_alias = alias ptr (), ptr @resolver_alias + +; CHECK-DAG: llvm.mlir.ifunc external @ifunc : !llvm.func, !llvm.ptr @resolver_alias +@ifunc = ifunc float (i64), ptr @resolver_alias +; CHECK-DAG: llvm.mlir.ifunc external @ifunc2 : !llvm.func, !llvm.ptr @resolver_alias_alias +@ifunc2 = ifunc float (i64), ptr @resolver_alias_alias + +; // ----- + +define ptr @resolver() { + ret ptr inttoptr (i64 333 to ptr) +} + +; CHECK: llvm.mlir.ifunc linkonce_odr hidden @ifunc +@ifunc = linkonce_odr hidden ifunc float (i64), ptr @resolver + +; // ----- + +define ptr @resolver() { + ret ptr inttoptr (i64 333 to ptr) +} + +; CHECK: llvm.mlir.ifunc private @ifunc {{.*}} {dso_local} +@ifunc = private dso_local ifunc float (i64), ptr @resolver + +; // ----- + +define ptr @resolver() { + ret ptr inttoptr (i64 333 to ptr) +} + +; CHECK: llvm.mlir.ifunc weak @ifunc +@ifunc = weak ifunc float (i64), ptr @resolver diff --git a/mlir/test/Target/LLVMIR/ifunc.mlir b/mlir/test/Target/LLVMIR/ifunc.mlir new file mode 100644 index 0000000000000..d10392d2bac2a --- /dev/null +++ b/mlir/test/Target/LLVMIR/ifunc.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-translate -mlir-to-llvmir %s --split-input-file | FileCheck %s + +// CHECK: @foo = dso_local ifunc void (ptr, i32), ptr @resolve_foo +llvm.mlir.ifunc external @foo : !llvm.func, !llvm.ptr @resolve_foo {dso_local} +llvm.func @call_foo(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i32 {llvm.noundef}) attributes {dso_local} { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr + %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr + llvm.store %arg0, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr + llvm.store %arg1, %2 {alignment = 4 : i64} : i32, !llvm.ptr + %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr + %4 = llvm.load %2 {alignment = 4 : i64} : !llvm.ptr -> i32 +// CHECK: call void @foo + llvm.call @foo(%3, %4) : (!llvm.ptr {llvm.noundef}, i32 {llvm.noundef}) -> () + llvm.return +} +llvm.func @call_indirect_foo(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i32 {llvm.noundef}) attributes {dso_local} { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.addressof @foo : !llvm.ptr + %2 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr + %3 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr + %4 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr + llvm.store %arg0, %2 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr + llvm.store %arg1, %3 {alignment = 4 : i64} : i32, !llvm.ptr +// CHECK: store ptr @foo, ptr [[STORED:%[0-9]+]] + llvm.store %1, %4 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr +// CHECK: [[LOADED:%[0-9]+]] = load ptr, ptr [[STORED]] + %5 = llvm.load %4 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr + %6 = llvm.load %2 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr + %7 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> i32 +// CHECK: call void [[LOADED]] + llvm.call %5(%6, %7) : !llvm.ptr, (!llvm.ptr {llvm.noundef}, i32 {llvm.noundef}) -> () + llvm.return +} +llvm.func internal @resolve_foo() -> !llvm.ptr attributes {dso_local} { + %0 = llvm.mlir.addressof @foo_1 : !llvm.ptr + llvm.return %0 : !llvm.ptr +} +llvm.func @foo_1(!llvm.ptr {llvm.noundef}, i32 {llvm.noundef}) + +// ----- + +llvm.mlir.alias external @resolver_alias : !llvm.func { + %0 = llvm.mlir.addressof @resolver : !llvm.ptr + llvm.return %0 : !llvm.ptr +} +llvm.mlir.alias external @resolver_alias_alias : !llvm.func { + %0 = llvm.mlir.addressof @resolver_alias : !llvm.ptr + llvm.return %0 : !llvm.ptr +} + +// CHECK-DAG: @ifunc = ifunc float (i64), ptr @resolver_alias +// CHECK-DAG: @ifunc2 = ifunc float (i64), ptr @resolver_alias_alias +llvm.mlir.ifunc external @ifunc2 : !llvm.func, !llvm.ptr @resolver_alias_alias +llvm.mlir.ifunc external @ifunc : !llvm.func, !llvm.ptr @resolver_alias +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: @ifunc = linkonce_odr hidden ifunc + +llvm.mlir.ifunc linkonce_odr hidden @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: @ifunc = private ifunc + +llvm.mlir.ifunc private @ifunc : !llvm.func, !llvm.ptr @resolver {dso_local} +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// ----- + +// CHECK: @ifunc = weak ifunc + +llvm.mlir.ifunc weak @ifunc : !llvm.func, !llvm.ptr @resolver +llvm.func @resolver() -> !llvm.ptr { + %0 = llvm.mlir.constant(333 : i64) : i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +}