How to bind a void pointer in the new FFI API #27763
Hi @dfm, I am binding a C library to JAX via the new FFI API. However, it's unclear to me how to bind a void pointer (`app` below, which is actually a pointer to a C struct that stores the configuration of the computation to be run). I tried the .Ctx and .Attr tags, but neither of them works. Note that the pointer points to a highly complex struct defined in https://github.com/DTolm/VkFFT/blob/066a17c17068c0f11c9298d848c2976c71fad1c1/vkFFT/vkFFT/vkFFT_Structs/vkFFT_Structs.h#L324
namespace ffi = xla::ffi;
// Forward FFT implementation function
ffi::Error VkFFTForwardImpl(void *app, ffi::AnyBuffer input,
                            ffi::AnyBuffer kernel, // This could be an empty buffer
                            ffi::Result<ffi::AnyBuffer> output)
{
    // Cast the void* to VkFFTApplication*
    VkFFTApplication *vkapp = static_cast<VkFFTApplication*>(app);
    // Only use kernel if it has elements
    void *kernel_ptr = (kernel.element_count() > 0) ? kernel.untyped_data() : NULL;
    // Call the VkFFT fft function with possibly NULL kernel
    int result = fft(vkapp,
                     input.untyped_data(),
                     output->untyped_data(),
                     kernel_ptr);
    if (result != 0)
    {
        return ffi::Error::Internal("VkFFT forward execution failed with code: " + std::to_string(result));
    }
    return ffi::Error::Success();
}
// Inverse FFT implementation function
ffi::Error VkFFTInverseImpl(void *app, ffi::AnyBuffer input,
                            ffi::Result<ffi::AnyBuffer> output)
{
    // Cast the void* to VkFFTApplication*
    VkFFTApplication *vkapp = static_cast<VkFFTApplication*>(app);
    // Call the VkFFT ifft function
    int result = ifft(vkapp,
                      input.untyped_data(),
                      output->untyped_data());
    if (result != 0)
    {
        return ffi::Error::Internal("VkFFT inverse execution failed with code: " + std::to_string(result));
    }
    return ffi::Error::Success();
}
// Update the handler registration
XLA_FFI_DEFINE_HANDLER_SYMBOL(VkFFTForward, VkFFTForwardImpl,
                              ffi::Ffi::Bind()
                                  .Ctx<void *>()
                                  .Arg<ffi::AnyBuffer>()         // input
                                  .OptionalArg<ffi::AnyBuffer>() // kernel
                                  .Ret<ffi::AnyBuffer>()         // output
);
XLA_FFI_DEFINE_HANDLER_SYMBOL(VkFFTInverse, VkFFTInverseImpl,
                              ffi::Ffi::Bind()
                                  .Ctx<void *>()
                                  .Arg<ffi::AnyBuffer>() // input
                                  .Ret<ffi::AnyBuffer>() // output
);
Replies: 1 comment 9 replies
Good question! I believe that the following should work:
ffi::Error Impl(void* ptr) {...}
XLA_FFI_DEFINE_HANDLER_SYMBOL(..., Impl,
                              ffi::Ffi::Bind()
                                  .Attr<ffi::Pointer<void>>()
);
Then, from Python, you would pass the pointer as a
Many questions here - let me try to answer them all!
If you can already call your function from Python, another option would be to use jax.pure_callback. You can typically get better performance and more features (e.g. in-place operations as discussed below) via the FFI, at the cost of some extra interface code.
Yep! The parameter you're looking for is input_output_aliases for ffi_call.
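For readers weighing the jax.pure_callback route mentioned above, here is a minimal sketch of what it can look like. It assumes the library is already callable from Python; since no such wrapper appears in this thread, `host_fft` is a hypothetical stand-in that simply uses NumPy's FFT on the host, and `fft_via_callback` is likewise an illustrative name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fft(x):
    # Hypothetical stand-in for a function that is already callable from
    # Python (for example a wrapper around a C library). Runs on the host.
    return np.fft.fft(np.asarray(x)).astype(np.complex64)


@jax.jit
def fft_via_callback(x):
    # JAX cannot trace into the Python function, so the output shape and
    # dtype have to be declared up front.
    out_type = jax.ShapeDtypeStruct(x.shape, jnp.complex64)
    return jax.pure_callback(host_fft, out_type, x)


x = jnp.arange(8, dtype=jnp.complex64)
y = fft_via_callback(x)
```

The callback receives host arrays and returns a host array, so every call moves data through Python. That round trip is the main reason the FFI route tends to be faster, and it is also why features such as buffer aliasing via input_output_aliases are only available through ffi_call.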