Slow compilation when jitting method of a class #16020
-
I get the error … Here is the code that produces the error: …
I'm using JAX …
-
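The issue is that you are marking `self` as static. This is problematic for a number of reasons, as described in FAQ: How to use JIT with Methods. In particular, if `self` is marked static, then any attributes of `self` will be embedded as literal constants in the compiled program, and embedding a 600,000-element constant is going to lead to poor performance (thus the warning you're seeing).

I would recommend following Option 3 in that FAQ entry; it might look something like this:

```python
import numpy as np
import jax
import jax.numpy as jnp

np.random.seed(0)


# Registering C as a pytree lets jit trace its array attributes
# instead of baking them into the program as constants.
@jax.tree_util.register_pytree_node_class
class C:
    def __init__(self, d, w):
        self.d = d
        self.w = w

    @jax.jit
    def foobar(self):
        return self.d @ self.w

    def tree_flatten(self):
        children = (self.d, self.w)  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


d = jnp.array(np.random.rand(200000, 3))
w = jnp.array(np.random.rand(3,))
c = C(d, w)
c.foobar()
```

If you run this, you'll find that it compiles more or less instantly.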
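For a side-by-side check of the explanation above, here is a minimal sketch that counts how often each version is traced. `StaticC` is a hypothetical reconstruction of the static-`self` pattern (decorating the method with `partial(jax.jit, static_argnums=0)`); the original question code is not shown here, so treat that part as an assumption. `PytreeC` follows the pytree approach from the answer. The counters work because the Python body of a jitted function only runs while JAX is tracing it, so each append records one (re)trace. Small arrays are used so the static version stays quick.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# The body of a jitted function runs only during tracing,
# so each append marks one (re)trace.
static_traces = []
pytree_traces = []


class StaticC:
    """Hypothetical reconstruction of the 'static self' pattern."""

    def __init__(self, d, w):
        self.d = d
        self.w = w

    @partial(jax.jit, static_argnums=0)
    def foobar(self):
        static_traces.append(1)
        # self.d and self.w are concrete arrays here, not tracers,
        # so they are embedded in the compiled program as constants.
        return self.d @ self.w


@jax.tree_util.register_pytree_node_class
class PytreeC:
    """The pytree approach (Option 3) from the answer above."""

    def __init__(self, d, w):
        self.d = d
        self.w = w

    @jax.jit
    def foobar(self):
        pytree_traces.append(1)
        # self.d and self.w are tracers: only shapes/dtypes matter here.
        return self.d @ self.w

    def tree_flatten(self):
        return (self.d, self.w), {}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


rng = np.random.default_rng(0)
d1 = jnp.asarray(rng.random((1000, 3)))
d2 = jnp.asarray(rng.random((1000, 3)))
w = jnp.asarray(rng.random(3))

# Keep the instances alive so they are clearly distinct objects.
sc1, sc2 = StaticC(d1, w), StaticC(d2, w)
pc1, pc2 = PytreeC(d1, w), PytreeC(d2, w)

s1, s2 = sc1.foobar(), sc2.foobar()
p1, p2 = pc1.foobar(), pc2.foobar()

print(len(static_traces))  # expected: 2 (one trace per new StaticC instance)
print(len(pytree_traces))  # expected: 1 (same shapes and dtypes, cache hit)
```

With `self` static, `jit` keys its cache on the object itself (default identity hash and equality), so every new instance is retraced and recompiled with fresh constants. With the pytree version, only the tree structure and the shapes and dtypes of the leaves enter the cache key, so new array values reuse the already compiled program.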
-
Follow-up: if you want a nicer syntax for Jake's answer, then try Equinox. This defines …