My value_and_grad taking longer than expected (I'm sure I'm doing something wrong) #17183

Closed Answered by PaulScemama
PaulScemama asked this question in Q&A
I forgot to jax.jit the value_and_grad function. To jit a method of a dataclass, self has to be passed as a static argument, which requires the instance to be hashable:

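from dataclasses import dataclass
from functools import partial

import jax

@dataclass(frozen=True, eq=True)  # frozen + eq makes instances hashable
class foo:
    ...

    # self is argument 0; list any other static arguments too
    @partial(jax.jit, static_argnums=(0,))
    def method(self, *args):
        ...  # do stuff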

In my case this looks like:

@dataclass(eq=True, frozen=True)
class LogProbability:
    likelihood: Callable  # but not really
    prior: Callable  # but not really
    num_batches: int

    @partial(jax.jit, static_argnums=(0,1))
    def value(
        self,
        net_apply: Callable,
        net_state: Dict,
        params: Dict,
        x: jax.Array,
        y: jax.Array,
        is_training: bool = False,
…
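
For anyone who wants something runnable, here is a minimal sketch of the same pattern, not the original code. `ToyLogProb`, its `scale` field, and the linear model are hypothetical stand-ins made up for illustration. The only parts taken from the answer above are the frozen, eq-enabled dataclass and `static_argnums=(0,)` on the jitted methods.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@dataclass(frozen=True, eq=True)  # frozen + eq gives the instance a __hash__
class ToyLogProb:  # hypothetical stand-in, not the original LogProbability
    scale: float  # hypothetical field, hashable so self can be static

    # self is argument 0; marking it static makes it a compile-time constant
    @partial(jax.jit, static_argnums=(0,))
    def value(self, params, x, y):
        pred = x @ params["w"] + params["b"]
        return -self.scale * jnp.sum((pred - y) ** 2)

    @partial(jax.jit, static_argnums=(0,))
    def value_and_grad(self, params, x, y):
        # Differentiates with respect to params, the first argument
        # of the bound method self.value.
        return jax.value_and_grad(self.value)(params, x, y)


lp = ToyLogProb(scale=0.5)
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# The first call compiles; later calls with the same shapes reuse the compiled code.
val, grads = lp.value_and_grad(params, x, y)
```

Because `self` is static, two instances that compare equal share a compiled function, while an instance with different field values triggers a recompile.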

Answer selected by PaulScemama