Skip to content

Commit af087e0

Browse files
joker-ephmemfrob
authored andcommitted
Add a "register_runtime" method to the mlir.execution_engine and show calling back from MLIR into Python
This exposes the ability to register Python functions with the JIT and exposes them to the MLIR jitted code. The provided test case illustrates the mechanism. Differential Revision: https://reviews.llvm.org/D99562
1 parent 354fa1f commit af087e0

File tree

5 files changed

+71
-1
lines changed

5 files changed

+71
-1
lines changed

mlir/include/mlir-c/ExecutionEngine.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
6161
MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
6262
MlirStringRef name);
6363

64+
/// Register a symbol with the jit: this symbol will be accessible to the jitted
65+
/// code.
66+
MLIR_CAPI_EXPORTED void
67+
mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name,
68+
void *sym);
69+
6470
#ifdef __cplusplus
6571
}
6672
#endif

mlir/lib/Bindings/Python/ExecutionEngine.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,17 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) {
8181
auto *res = mlirExecutionEngineLookup(
8282
executionEngine.get(),
8383
mlirStringRefCreate(func.c_str(), func.size()));
84-
return (int64_t)res;
84+
return reinterpret_cast<uintptr_t>(res);
85+
},
86+
"Lookup function `func` in the ExecutionEngine.")
87+
.def(
88+
"raw_register_runtime",
89+
[](PyExecutionEngine &executionEngine, const std::string &name,
90+
uintptr_t sym) {
91+
mlirExecutionEngineRegisterSymbol(
92+
executionEngine.get(),
93+
mlirStringRefCreate(name.c_str(), name.size()),
94+
reinterpret_cast<void *>(sym));
8595
},
8696
"Lookup function `func` in the ExecutionEngine.");
8797
}

mlir/lib/Bindings/Python/mlir/execution_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,11 @@ def invoke(self, name, *ctypes_args):
2929
for argNum in range(len(ctypes_args)):
3030
packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
3131
func(packed_args)
32+
33+
def register_runtime(self, name, ctypes_callback):
34+
"""Register a runtime function available to the jitted code
35+
under the provided `name`. The `ctypes_callback` must be a
36+
`CFuncType` that outlives the execution engine.
37+
"""
38+
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
39+
self.raw_register_runtime("_mlir_ciface_" + name, callback)

mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/CAPI/IR.h"
1212
#include "mlir/CAPI/Support.h"
1313
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
14+
#include "llvm/ExecutionEngine/Orc/Mangling.h"
1415
#include "llvm/Support/TargetSelect.h"
1516

1617
using namespace mlir;
@@ -54,3 +55,14 @@ extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
5455
return nullptr;
5556
return reinterpret_cast<void *>(*expectedFPtr);
5657
}
58+
59+
extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
60+
MlirStringRef name,
61+
void *sym) {
62+
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
63+
llvm::orc::SymbolMap symbolMap;
64+
symbolMap[interner(unwrap(name))] =
65+
llvm::JITEvaluatedSymbol::fromPointer(sym);
66+
return symbolMap;
67+
});
68+
}

mlir/test/Bindings/Python/execution_engine.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,37 @@ def testInvokeFloatAdd():
9797
log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
9898

9999
run(testInvokeFloatAdd)
100+
101+
102+
# Test callback
103+
# CHECK-LABEL: TEST: testBasicCallback
104+
def testBasicCallback():
105+
# Define a callback function that takes a float and an integer and returns a float.
106+
@ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
107+
def callback(a, b):
108+
return a/2 + b/2
109+
110+
with Context():
111+
# The module just forwards to a runtime function known as "some_callback_into_python".
112+
module = Module.parse(r"""
113+
func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
114+
%resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
115+
return %resf : f32
116+
}
117+
func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
118+
""")
119+
execution_engine = ExecutionEngine(lowerToLLVM(module))
120+
execution_engine.register_runtime("some_callback_into_python", callback)
121+
122+
# Prepare arguments: two input floats and one result.
123+
# Arguments must be passed as pointers.
124+
c_float_p = ctypes.c_float * 1
125+
c_int_p = ctypes.c_int * 1
126+
arg0 = c_float_p(42.)
127+
arg1 = c_int_p(2)
128+
res = c_float_p(-1.)
129+
execution_engine.invoke("add", arg0, arg1, res)
130+
# CHECK: 42.0 + 2 = 44.0
131+
log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
132+
133+
run(testBasicCallback)

0 commit comments

Comments
 (0)