
Slow compilation when jitting method of a class #16020

Answered by jakevdp
IrishWhiskey asked this question in Q&A

The issue is that you are marking self as static. This is problematic for a number of reasons, as described in FAQ: How to use JIT with Methods.

In particular, if self is marked static then any attributes of self will be embedded as literal constants in the compiled program, and embedding a 600,000-element constant is going to lead to poor performance (thus the warning you're seeing).
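A minimal sketch of the problematic pattern, assuming a hypothetical class `StaticC` with a hypothetical method `apply` that scales its input by `self.w`. Here `self` is made static with `partial(jax.jit, static_argnums=0)`, so `self.w` is a concrete array at trace time and is baked into the compiled program as a constant. Static arguments are also part of the compilation cache key, so every new instance triggers a fresh trace and compile. The `trace_count` counter is illustrative only: it increments when JAX traces the function, not on cached calls.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented only while JAX traces (not on cached calls)


class StaticC:
    """Hypothetical class that marks `self` as static (the pattern to avoid)."""

    def __init__(self, w):
        self.w = jnp.asarray(w)

    # static_argnums=0 makes `self` static. It must be hashable; the default
    # object hash is identity-based, so every instance is a new cache key.
    @partial(jax.jit, static_argnums=0)
    def apply(self, x):
        global trace_count
        trace_count += 1
        # self.w is a concrete array here, so it is embedded as a constant.
        return self.w * x


x = jnp.ones(3)
a = StaticC(np.arange(3.0))
b = StaticC(np.arange(3.0) + 1.0)

print(a.apply(x))  # first trace and compile for instance a
print(a.apply(x))  # cache hit: same instance, same input shape
print(b.apply(x))  # new instance: another trace and compile
print(trace_count)
```

With a 600,000-element attribute, each of those compiles would carry the whole array as a literal constant.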

I would recommend following Option 3 at the above link; it might look something like this:

import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

np.random.seed(0)

@jax.tree_util.register_pytree_node_class
class C:
    def __init__(self, d, w):
        self.d = d
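        self.w = w

    # Assumed completion (the rest of the original example was cut off):
    # return both attributes as dynamic pytree children so they are traced
    # as arrays instead of being embedded as compile-time constants.
    def tree_flatten(self):
        return (self.d, self.w), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)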
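A self-contained sketch of the Option 3 pattern in use, under stated assumptions: the method `apply` and its body (an elementwise product of `d` and `w` dotted with an input) are hypothetical, chosen only for illustration, and `trace_count` is an illustrative counter. Because `C` is a registered pytree, `jax.jit` flattens `self` into its children, so `self.d` and `self.w` are traced as array arguments. A second instance with the same shapes and dtypes reuses the compiled program instead of recompiling.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented only while JAX traces the method


@jax.tree_util.register_pytree_node_class
class C:
    def __init__(self, d, w):
        self.d = d
        self.w = w

    def tree_flatten(self):
        # Both attributes are arrays, so both are dynamic children;
        # this sketch has no static (aux) data.
        return (self.d, self.w), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def apply(self, x):
        # Hypothetical method: `self` is a pytree, so self.d and self.w
        # are tracers here rather than embedded constants.
        global trace_count
        trace_count += 1
        return jnp.dot(self.d * self.w, x)


n = 600_000
rng = np.random.default_rng(0)
x = jnp.ones(n)

c1 = C(jnp.asarray(rng.standard_normal(n)), jnp.asarray(rng.standard_normal(n)))
c2 = C(jnp.asarray(rng.standard_normal(n)), jnp.asarray(rng.standard_normal(n)))

r1 = c1.apply(x)  # traces and compiles once
r2 = c2.apply(x)  # same shapes/dtypes: reuses the compiled program
print(trace_count)
```

Non-array configuration (flags, sizes) would instead go into the aux data returned by `tree_flatten`; JAX treats aux data as static, so it must be hashable and changing it triggers recompilation.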

Answer selected by IrishWhiskey