-
Following the extending-jax repo and the JAX tutorial on custom primitives, I created a custom CUDA kernel that calculates the sum of a cumulative product. I have uploaded it here: https://github.com/frskplis/sumcumprod_jax. I am having trouble defining the batching rule for my primitive. The original extending-jax repo had a kernel that performed an element-wise operation, so its batching rule was simply the original primitive applied again. In my case, my CUDA kernel performs a calculation equivalent to the simplified code sketched below (see the repo for the exact version):
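(A minimal sketch, assuming — as the name suggests — that the operation reduces a 1D array to the scalar `np.sum(np.cumprod(x))`:)

```python
import numpy as np

def sumcumprod_reference(x):
    # Sum of the running (cumulative) product of a 1D array:
    # x[0] + x[0]*x[1] + x[0]*x[1]*x[2] + ...
    total = 0.0
    running = 1.0
    for value in x:
        running *= value
        total += running
    return total

# Equivalent one-liner with NumPy:
assert sumcumprod_reference(np.array([1.0, 2.0, 3.0])) == np.sum(np.cumprod([1.0, 2.0, 3.0]))
```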
My actual kernel is exposed to Python as the `sumcumprod` function in the repo, and I have verified that the two implementations give the same results on 1D inputs. I would like to use `vmap` to map this kernel over the first axis of a 2D array, so that each row is reduced independently, like this:
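(Here `sumcumprod` is the Python binding from my repo; the expected values below are computed with stock JAX ops, under the assumption that `sumcumprod` reduces a 1D array to a scalar:)

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1.0, 7.0).reshape(2, 3)  # rows [1, 2, 3] and [4, 5, 6]

# Expected: one sum-of-cumulative-products per row.
# row 0: 1 + 1*2 + 1*2*3 = 9
# row 1: 4 + 4*5 + 4*5*6 = 144
expected = jnp.array([jnp.sum(jnp.cumprod(row)) for row in x])
print(expected)  # [  9. 144.]

# Desired usage with the custom primitive:
# jax.vmap(sumcumprod)(x)  # should match `expected`
```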
But for my kernel the `vmap`ped result is totally wrong. This is because I have an incorrect `_sumcumprod_batch` in the `sumcumprod_jax.py` source, so effectively there is no batching in my case. I tried to look at the batching rules in the JAX source code, but to no avail. Please help me: what changes should I make in order for this to work correctly? cc: @dfm
-
The semantics of your batching rule should be something like this:

```python
def _sumcumprod_batch(args, axes):
    x, = args
    bd, = axes
    # Move the batch dimension to the front, apply the primitive to each
    # slice along it, and report the output batch axis as 0.
    x = jnp.moveaxis(x, bd, 0)
    x_slices = [x[i] for i in range(x.shape[0])]
    result_slices = [sumcumprod(x_slice) for x_slice in x_slices]
    return jnp.stack(result_slices), 0
```
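To make JAX use it, the rule gets registered with the batching interpreter; a minimal sketch, assuming your primitive object is named `_sumcumprod_prim` in the style of the extending-jax template (adjust the name to whatever your repo uses):

```python
from jax.interpreters import batching

# `_sumcumprod_prim` is a placeholder for your jax.core.Primitive instance.
batching.primitive_batchers[_sumcumprod_prim] = _sumcumprod_batch
```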
I think if you plug this in, it will work for your case, but unfortunately it's not very efficient due to the `for` loop within the list comprehension. But unless you generalize your primitive to be closed under batching, it's hard to do much better than this.

What do I mean by "closed under batching"? Consider the example of computing a vector product: in JAX, a simple vector product is lowered to the `dot_general` primitive:

```python
x = jnp.ones(10)
y = jnp.ones(10)
jax.make_jaxpr(jnp.dot)(x, y)
# { lambda ; a:f32[10] b:f32[10]. let
# c:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] a b
# in (c,) }
```

The parameters to this call specify the dimensions to be contracted: `dimension_numbers=(([0], [0]), ([], []))` means that axis 0 of each operand is contracted, and there are no batch dimensions. Now what happens if we `vmap` this computation?

```python
x_batched = jnp.ones((100, 10))
y_batched = jnp.ones((100, 10))
jax.make_jaxpr(jax.vmap(jnp.dot))(x_batched, y_batched)
# { lambda ; a:f32[100,10] b:f32[100,10]. let
# c:f32[100] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
# in (c,) }
```

Again it's a single call to `dot_general`, just with different parameters: now axis 1 of each operand is contracted, and axis 0 is treated as a batch dimension. Thus `dot_general` is closed under batching: the batched version of the operation is expressible as a single call to the same primitive with adjusted parameters, with no loop over the batch.

Now back to your primitive: it does not look like your primitive is closed under batching, and this is probably something it inherits from how your CUDA kernel is implemented. So in this case there's not really any obvious efficient way to express the batched operation. So your options are:

1. Keep the batching rule above and accept the Python-level loop over the batch dimension.
2. Express the operation using existing JAX primitives that are already closed under batching (e.g. `jnp.cumprod` followed by a sum), at the cost of bypassing your custom kernel.
3. Generalize your primitive and the underlying CUDA kernel to accept a batch dimension, so that the batched operation can be expressed as a single call back into the primitive.
Option 3 is probably the best, but unfortunately might also take much more work.
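For concreteness, a rough sketch of what the batching rule could become under option 3, assuming a hypothetical generalized binding `sumcumprod2d` whose kernel reduces each row of a 2D input in one launch:

```python
def _sumcumprod_batch_closed(args, axes):
    x, = args
    bd, = axes
    # With a kernel that natively handles a leading batch dimension, the
    # batched primitive is a single dispatch: no Python-level loop remains.
    x = jnp.moveaxis(x, bd, 0)
    return sumcumprod2d(x), 0  # `sumcumprod2d` is hypothetical
```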