-
Consider this example function, inspired by Huggingface's implementation of Dreambooth:

```python
def train_step(model1, model2, model3, model1_params, model2_params, model3_params):
    def loss_fn(params):
        ...
    ...
```
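To make the issue concrete, here is roughly what goes wrong when jitting this directly (a sketch: `model1`, `model2`, and `model3` are assumed to be Flax `nn.Module` instances as in the Flax Dreambooth script, and the exact error text varies by JAX version):

```python
import jax

jitted_train_step = jax.jit(train_step)
# Raises a TypeError: every traced argument must be an array (or a pytree of
# arrays), but an unbound Flax Module is treated as a single non-array leaf,
# so jit rejects it as "not a valid JAX type".
jitted_train_step(model1, model2, model3, model1_params, model2_params, model3_params)
```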
Here are a couple of things I've considered. In the Huggingface implementation, they solve this issue by making `train_step` a closure over the models, defining everything inside `main`:

```python
def main():
    model1 = ...
    model2 = ...
    model3 = ...
    ...

    def train_step(model1_params, model2_params, model3_params):
        ...
```

But this results in a pretty long function definition which is hard to read and modify. Flax's Quick Start suggests a more structured approach using `flax.training.train_state.TrainState` (sketched below for reference). Are there any other strategies for jit-compiling functions with non-pytree arguments?
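For reference, a minimal sketch of that `TrainState` pattern, assuming a toy `MLP` model and an `optax` optimizer rather than the actual Dreambooth models:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=10)(x)

def create_state(rng):
    model = MLP()
    params = model.init(rng, jnp.ones([1, 64]))["params"]
    # TrainState bundles apply_fn, params, and the optimizer; apply_fn and tx
    # are non-pytree fields, so the whole state can be passed through jit.
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
    )

@jax.jit
def train_step(state, x, y):
    def loss_fn(params):
        pred = state.apply_fn({"params": params}, x)
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

state = create_state(jax.random.PRNGKey(0))
state = train_step(state, jnp.ones([1, 64]), jnp.ones([1, 10]))
```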
-
Thanks for the question! Two ideas:

1. `jit`'s `static_argnums` (and `static_argnames` if you prefer).
2. `functools.partial`.
The latter has the same behavior as the lexical closure approach used by HF. The difference with the former is that it requires the model arguments to be hashable (with hashability providing more opportunities for cache hits when models compare equal).

In more detail, the first option might look like

```python
import jax

def train_step(model1, model2, model3, model1_params, model2_params, model3_params):
    def loss_fn(params):
        ...
    ...

# the first three arguments are treated as static: they must be hashable, and
# values that don't compare equal to a previous call trigger a recompile
train_step = jax.jit(train_step, static_argnums=(0, 1, 2))

# a call site might look something like this
train_step(model1, model2, model3, model1_params, model2_params, model3_params)
```

The second option might look like

```python
import jax
from functools import partial

def train_step(model1, model2, model3, model1_params, model2_params, model3_params):
    def loss_fn(params):
        ...
    ...

jit_train_step = jax.jit(partial(train_step, model1, model2, model3))
# or even, without needing to import `partial` (bound to a new name so the
# lambda keeps calling the original `train_step` rather than recursing into
# the jitted wrapper)
jit_train_step = jax.jit(lambda *args: train_step(model1, model2, model3, *args))

# a call site might look something like this
jit_train_step(model1_params, model2_params, model3_params)
```

What do you think?
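To illustrate the cache-hit point, a small sketch (the names `apply_model`, `model_a`, and `model_b` are made up, and it assumes unbound Flax modules hash and compare by value, as the reply above suggests):

```python
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn

@partial(jax.jit, static_argnums=0)
def apply_model(model, params, x):
    print("tracing")  # only runs while jit traces, not on cache hits
    return model.apply({"params": params}, x)

model_a = nn.Dense(features=10)
model_b = nn.Dense(features=10)  # distinct instance, same hyperparameters

params = model_a.init(jax.random.PRNGKey(0), jnp.ones([1, 64]))["params"]
x = jnp.ones([1, 64])

apply_model(model_a, params, x)  # prints "tracing": first compilation
apply_model(model_b, params, x)  # no print: model_b == model_a, so the cached trace is reused
```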
-
Indeed, the `functools.partial` approach works:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=10)(x)
        return x

def train_step(model, model_params, x):
    return model.apply({"params": model_params}, x)

def main():
    rng = jax.random.PRNGKey(0)
    model = MLP()
    model_params = model.init(rng, jnp.ones([1, 64]))['params']
    x = jax.random.normal(rng, (1, 64))
    train_step(model, model_params, x)

    partial_train_step = jax.jit(partial(train_step, model))
    partial_train_step(model_params, x)

if __name__ == "__main__":
    main()
```

Thank you for the help!
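For completeness, the `static_argnums` option from the reply would need only a two-line change to the snippet above, replacing the `partial` lines inside `main` (`static_train_step` is a hypothetical name):

```python
    # Mark argument 0 (the Flax module) as static: jit hashes it instead of
    # tracing it, and modules that compare equal reuse the compiled function.
    static_train_step = jax.jit(train_step, static_argnums=0)
    static_train_step(model, model_params, x)
```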