I think this will be the way to go. In JAX internals, primitive "parameters" supply static information, whereas "operands" are possibly dynamic. The fact that […] To exclude […]

By the way, we have a proposal for adding a proper FFI API in #12632, so that hopefully you won't need to dig into JAX internals/primitives in the future to set up custom calls. Feedback there is welcome. (cc @sharadmv)
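With p passed as an operand, one way to keep it out of the derivative using only public APIs is to wrap the call in jax.custom_jvp and ignore p's tangent in the rule. This is a minimal sketch that swaps in jax.custom_jvp for a hand-written primitive JVP rule. It assumes a hypothetical `_foo_kernel` that stands in for the C++ custom call (here simply x scaled by sum(p)). The real kernel and its derivative would replace these stand-ins. Because p is an integer array, JAX never asks for a gradient with respect to it unless you request one explicitly.

```python
import jax
import jax.numpy as jnp

def _foo_kernel(x, p):
    # Hypothetical stand-in for the C++ custom call: scale x by sum(p).
    return x * jnp.sum(p).astype(x.dtype)

@jax.custom_jvp
def foo(x, p):
    # p is an ordinary operand, so it may be a traced (dynamic) array.
    return _foo_kernel(x, p)

@foo.defjvp
def _foo_jvp(primals, tangents):
    x, p = primals
    x_dot, _ = tangents  # ignore p's tangent: p is excluded from differentiation
    y = foo(x, p)
    # d/dx [x * s] = s, so the output tangent is linear in x_dot only,
    # which also lets JAX derive the VJP by transposition.
    y_dot = x_dot * jnp.sum(p).astype(x.dtype)
    return y, y_dot

x = jnp.arange(3.0)
p = jnp.array([1, 2, 3], dtype=jnp.int32)

y = jax.jit(foo)(x, p)                        # p is traced here, which is fine for an operand
g = jax.grad(lambda x: foo(x, p).sum())(x)    # reverse mode w.r.t. x only
```

At the primitive level, the analogous idea is for the JVP rule to produce a tangent from x's tangent alone and drop p's.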
Hi JAX community,
I am writing a primitive that makes an XLA custom call (to a C++ function).
I know how to do it when my primitive does not have any keywords (parameters).
My problem is how to make it work when the primitive has parameters.
Here are some details:
y = foo(x, *, p)

- x is a regular array. I want to be able to differentiate with respect to it, etc.
- p is a parameter. In my case, it is an array of integers.
- y is the same for all x and p.

The C++ function receives the size of x (say it is a 1D array of size N), a pointer to the data in x, and the same for p (say also 1D, integers, of size D).
This does not work because p is not an XlaOp. It is a Traced<ShapedArray(int32[1000])>with<DynamicJaxprTrace(level=0/1)>.
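The traced value in that message is what JAX produces for any array passed in as a function argument while tracing (e.g. under jax.jit): it becomes an input of the traced computation, i.e. an operand, rather than a concrete value that could be fixed as a static parameter. The sketch below illustrates this with public APIs only. `foo_ref` is a hypothetical stand-in for the custom call (x scaled by sum(p)), not the actual kernel. If p truly is known ahead of time, one option is to pass it as a hashable tuple through static_argnums. A jnp array is not hashable, so JAX rejects it as a static argument.

```python
import jax
import jax.numpy as jnp

def foo_ref(x, p):
    # Hypothetical stand-in for the custom call: scale x by sum(p).
    return x * jnp.sum(jnp.asarray(p)).astype(x.dtype)

x = jnp.arange(3.0)
p = jnp.array([1, 2, 3], dtype=jnp.int32)

# Passed as an ordinary argument, p is traced and shows up as a jaxpr input (operand).
closed = jax.make_jaxpr(foo_ref)(x, p)
num_inputs = len(closed.jaxpr.invars)  # one input for x, one for p

# Passed as a static argument, p must be hashable, so use a tuple of Python ints.
foo_static = jax.jit(foo_ref, static_argnums=1)
y_static = foo_static(x, (1, 2, 3))
```

The static route recompiles for every distinct value of p, which is why the operand route is usually preferable when p can change between calls.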
Note that everything works fine if I treat p just like x (as a regular input, not as a parameter). I'm okay with a solution that treats p just like x for the XLA call if I can exclude p from JVP/VJP.

Regards,
ib