Strategy for jit-compiling functions with non-pytree arguments #14689

Answered by mattjj
dfdx asked this question in General

Thanks for the question!

Two ideas:

  1. If the model objects are hashable, mark them as static with jit's static_argnums (or static_argnames, if you prefer to refer to them by name).
  2. If they're not hashable, bind them with functools.partial before applying jit.
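A minimal sketch of the second idea, assuming a hypothetical `UnhashableModel` class and `loss` function (neither comes from the original discussion). The model is bound with `functools.partial`, so `jax.jit` only ever sees the array arguments and the model never needs to be hashed.

```python
import functools

import jax
import jax.numpy as jnp


class UnhashableModel:
    """Hypothetical model object that cannot be hashed (e.g. it holds mutable state)."""

    __hash__ = None  # explicitly unhashable, so static_argnums would reject it

    def __init__(self, scale):
        self.scale = scale

    def apply(self, params, x):
        return self.scale * jnp.tanh(x @ params)


def loss(model, params, x):
    return jnp.sum(model.apply(params, x) ** 2)


model = UnhashableModel(2.0)

# Bind the model with functools.partial; jit only sees (params, x) as arguments,
# and `model` is captured as a plain Python object while the function is traced.
loss_for_model = jax.jit(functools.partial(loss, model))

params = jnp.eye(3)
x = jnp.ones((2, 3))
value = loss_for_model(params, x)
```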

The latter (functools.partial) behaves the same as the lexical-closure approach used by HF. The former (static_argnums) differs in that it requires hashability, which in turn provides more opportunities for cache hits when models compare equal.
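To illustrate the cache-hit point, here is a sketch assuming a hypothetical `Model` written as a frozen dataclass, which gives it field-based `__eq__` and `__hash__`. A Python-side counter records how often `loss` is traced: two distinct but equal models reuse the same compiled function, while an unequal model triggers a retrace.

```python
import dataclasses

import jax
import jax.numpy as jnp

# Python-side counter; the body of `loss` only runs when JAX (re)traces it.
trace_count = {"n": 0}


@dataclasses.dataclass(frozen=True)
class Model:
    """Hypothetical hashable model: frozen dataclass => field-based __eq__ and __hash__."""

    scale: float

    def apply(self, params, x):
        return self.scale * jnp.tanh(x @ params)


def loss(model, params, x):
    trace_count["n"] += 1  # side effect happens at trace time only
    return jnp.sum(model.apply(params, x) ** 2)


# Argument 0 (the model) is static, so it must be hashable.
loss_jit = jax.jit(loss, static_argnums=0)

params = jnp.eye(3)
x = jnp.ones((2, 3))

loss_jit(Model(2.0), params, x)  # first call: traces and compiles
loss_jit(Model(2.0), params, x)  # distinct object, but equal with the same hash: cache hit
traces_after_equal_models = trace_count["n"]

loss_jit(Model(3.0), params, x)  # unequal model: traces again
traces_after_new_model = trace_count["n"]
```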

In more detail, the first option might look like this:

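import jax

def train_step(model1, model2, model3, model1_params, model2_params, model3_params):
    # model1..model3 are hashable Python objects; the *_params arguments are pytrees of arrays
    def loss_fn(params):
        ...
    ...

# Mark the three model arguments as static: jit hashes them instead of tracing them
train_step = jax.jit(train_step, static_argnums=(0, 1, 2))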

# a call site might look something like this
tr…
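For reference, here is one way the first option could be filled out end to end. This is a hypothetical sketch, not the author's code: `Linear`, `init`, the three-stage pipeline, and the plain SGD update are all assumptions. The model objects hold only static configuration and are passed as static arguments, while their parameters are ordinary pytrees that get traced and differentiated.

```python
import dataclasses

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # assumed constant step size for plain SGD


@dataclasses.dataclass(frozen=True)
class Linear:
    """Hypothetical hashable model: stores only static config; params live outside."""

    activation: str = "tanh"

    def apply(self, params, x):
        y = x @ params["w"] + params["b"]
        return jnp.tanh(y) if self.activation == "tanh" else y


def init(key, n_in, n_out):
    # Hypothetical initializer returning a params pytree (dict of arrays).
    return {"w": 0.1 * jax.random.normal(key, (n_in, n_out)), "b": jnp.zeros(n_out)}


def train_step(model1, model2, model3, model1_params, model2_params, model3_params, x, y):
    def loss_fn(params):
        p1, p2, p3 = params
        h = model1.apply(p1, x)     # stage 1
        h = model2.apply(p2, h)     # stage 2
        pred = model3.apply(p3, h)  # stage 3 (output)
        return jnp.mean((pred - y) ** 2)

    params = (model1_params, model2_params, model3_params)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD: subtract the scaled gradient from every leaf of the params pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return loss, new_params


# Models (args 0-2) are static and must be hashable; params, x and y are traced.
train_step = jax.jit(train_step, static_argnums=(0, 1, 2))

# Hypothetical call site.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
p1, p2, p3 = init(k1, 4, 8), init(k2, 8, 8), init(k3, 8, 1)
m1, m2, m3 = Linear(), Linear(), Linear(activation="identity")
x = jnp.ones((5, 4))
y = jnp.zeros((5, 1))

loss0, (p1, p2, p3) = train_step(m1, m2, m3, p1, p2, p3, x, y)
loss1, _ = train_step(m1, m2, m3, p1, p2, p3, x, y)
```

Because the models are static, calling `train_step` again with the same (or equal) model objects reuses the compiled function; only new, unequal models trigger recompilation.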

Replies: 2 comments 1 reply

Answer selected by dfdx

2 participants