
How to implement batching rule for custom cuda op? #16840

Answered by jakevdp
frskplis asked this question in Q&A

The semantics of your batching rule should be something like this:

def _sumcumprod_batch(args, axes):
    x, = args   # single input operand
    bd, = axes  # the axis along which it is batched
    x = jnp.moveaxis(x, bd, 0)  # move the batch dimension to the front
    x_slices = [x[i] for i in range(x.shape[0])]
    result_slices = [sumcumprod(x_slice) for x_slice in x_slices]
    return jnp.stack(result_slices), 0  # output is batched along axis 0

I think if you plug this in, it will work for your case, but unfortunately it's not very efficient: the Python-level loop over batch elements launches the primitive once per slice rather than as a single batched kernel. But unless you generalize your primitive to be closed under batching, it's hard to do much better than this.
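For completeness, here is a sketch of how a rule like this gets wired up end to end. This is an assumption-laden illustration, not the asker's actual op: the names `sumcumprod`/`sumcumprod_p` are hypothetical, and a pure-JAX implementation (sum of the running product) stands in for the CUDA kernel so the example runs anywhere.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:  # Primitive moved to jax.extend.core in newer JAX releases
    from jax.extend import core
except ImportError:
    from jax import core

# Hypothetical primitive standing in for the custom CUDA op.
sumcumprod_p = core.Primitive("sumcumprod")

def sumcumprod(x):
    return sumcumprod_p.bind(x)

# Pure-JAX stand-in implementation for the CUDA kernel:
# sumcumprod(x) = sum of the cumulative product of a 1D vector.
sumcumprod_p.def_impl(lambda x: jnp.sum(jnp.cumprod(x)))

# The batching rule from the answer, iterating over the leading
# (batch) axis after moveaxis puts it at the front:
def _sumcumprod_batch(args, axes):
    x, = args
    bd, = axes
    x = jnp.moveaxis(x, bd, 0)
    result_slices = [sumcumprod(x[i]) for i in range(x.shape[0])]
    return jnp.stack(result_slices), 0

# Registering the rule is what makes jax.vmap understand the primitive:
batching.primitive_batchers[sumcumprod_p] = _sumcumprod_batch

batch = jnp.arange(1.0, 7.0).reshape(2, 3)  # two independent inputs
out = jax.vmap(sumcumprod)(batch)
# == [sumcumprod([1,2,3]), sumcumprod([4,5,6])] == [9., 144.]
```

The key line is the `batching.primitive_batchers[...]` assignment: without it, `jax.vmap` has no rule for the primitive and raises a `NotImplementedError`.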

What do I mean by "closed under batching"? Consider the example of computing a vector product: in JAX, a simple vector pr…
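To illustrate the general idea behind "closed under batching" (a sketch, not the remainder of the truncated answer above): `jnp.dot` lowers to the general `dot_general` primitive, and batching a `dot_general` yields another single `dot_general` (here a matrix-vector product) rather than a Python loop over batch elements.

```python
import jax
import jax.numpy as jnp

# A dot product is "closed under batching": vmapping it produces one
# instance of the same general primitive (dot_general), not a loop of
# per-example dot products.
def dot(x, y):
    return jnp.dot(x, y)

xs = jnp.ones((4, 3))  # a batch of 4 vectors
y = jnp.ones(3)

batched = jax.vmap(dot, in_axes=(0, None))
out = batched(xs, y)                    # same result as xs @ y
jaxpr = jax.make_jaxpr(batched)(xs, y)  # contains a single dot_general
```

Because the batched computation is again a single primitive call, it runs as one fused kernel instead of one launch per batch element, which is exactly what the loop-based rule above cannot do.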

Answer selected by frskplis