Replies: 3 comments 2 replies
-
I'm going to use
-
You can work around the original error with `jax.jit(lambda: params, out_shardings=param_spec)()  # no crash`. We're working on improving the error message and possibly changing the API to accept this. cc @yashk2810
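Spelled out a bit, that workaround looks like the sketch below; the mesh axis name, parameter tree, and shapes are placeholders rather than anything from the original code.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative parameter pytree; in practice this comes from model init.
params = {"dense": {"kernel": np.zeros((1024, 1024)), "bias": np.zeros((1024,))}}

# A 1-D mesh over all local devices with a single named axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# One NamedSharding per leaf: shard the kernel's leading dim, replicate the bias.
param_spec = {
    "dense": {
        "kernel": NamedSharding(mesh, PartitionSpec("data", None)),
        "bias": NamedSharding(mesh, PartitionSpec()),
    }
}

# The workaround: jit a no-argument closure over params and let out_shardings
# place the outputs onto the mesh.
sharded_params = jax.jit(lambda: params, out_shardings=param_spec)()
```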
-
Any update on this? The documentation says
-
Hi :) I'm really excited to try sharding weights across devices! This will help me make much better use of access to TPUv2-8 VMs. I found this issue that seems really similar to the one that I'm having, but it's been closed. I've been struggling to port this code to work with stable diffusion. I'll try to illustrate what happens when I try to use the `pjit` merge to factor out the `pmap` decorating my `train_step`. If I strip out some initialization and optimization code, what I end up with is roughly the sketch below.

The code fails with a `RuntimeError: jit does not support using the mesh context manager and passing PartitionSpecs to in_shardings or out_shardings. Please pass in the Sharding explicitly via in_shardings or out_shardings.`
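A minimal sketch of the setup, with placeholder module names, shapes, and a stubbed-out `train_step` standing in for my real stable diffusion training code:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Placeholder pytree standing in for the stable diffusion weights.
params = {"unet": {"kernel": np.zeros((512, 512)), "bias": np.zeros((512,))}}

# A 1-D mesh over the 8 TPU cores; it is never entered as a context manager.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# An explicit NamedSharding for every leaf of the parameter tree.
param_spec = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, PartitionSpec()), params
)

def train_step(params, batch):
    # Stand-in for the real loss and optimizer update.
    return params

p_train_step = jax.jit(train_step, out_shardings=param_spec)

# In my setup, tracing the jitted step is where the RuntimeError above is raised.
new_params = p_train_step(params, np.zeros((8, 64)))
```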
This is puzzling, considering that I created a `Mesh` object but didn't use it as a context manager, and did exactly as instructed by the error, explicitly specifying `NamedSharding`s for each leaf of the parameter tree. It seems to work when I use it explicitly for device placement. What's actually going on here?

Edit: made the 'minimal example' more minimal 😅
Update: Tried with a `PositionalSharding` instead; that crashed with an identical trace.
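For completeness, the `PositionalSharding` attempt was along these lines, again with placeholder shapes and assuming the 8 devices of a TPUv2-8:

```python
import jax
import numpy as np
from jax.sharding import PositionalSharding

# Same placeholder parameter tree as in the sketch above.
params = {"unet": {"kernel": np.zeros((512, 512)), "bias": np.zeros((512,))}}

# One sharding over all 8 devices, reshaped per leaf so its rank matches the
# array: shard the leading dimension, leave the remaining dimensions whole.
sharding = PositionalSharding(jax.devices())
param_spec = jax.tree_util.tree_map(
    lambda x: sharding.reshape((8,) + (1,) * (x.ndim - 1)), params
)

def train_step(params, batch):
    # Stand-in for the real update.
    return params

# Crashes with the same trace as the NamedSharding version in my setup.
p_train_step = jax.jit(train_step, out_shardings=param_spec)
new_params = p_train_step(params, np.zeros((8, 64)))
```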