One problem is that you are creating the mesh with `Manual` mesh axes. If you don't do that and just pass the mesh to `shard_map`, it should work, or at least give you a different error. `shard_map` will switch the mesh axes for you.
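A minimal sketch of that suggestion, assuming a recent JAX. It builds a plain `Mesh` over all visible devices with no `Manual` axis types, passes it to `shard_map`, and lets an outer `vmap` batch only the parameter while the sharded constants stay unbatched. The original code isn't shown, so the Gaussian log-likelihood body, `_local_loglike`, `sharded_loglike`, and the `"x"` axis name are hypothetical stand-ins. The import fallback covers both the top-level `jax.shard_map` and the older `jax.experimental.shard_map` location.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Newer JAX exposes jax.shard_map; older versions only have the experimental one.
shard_map = getattr(jax, "shard_map", None)
if shard_map is None:
    from jax.experimental.shard_map import shard_map

# Plain mesh, no Manual axis types: shard_map handles the manual switch itself.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n_dev = devices.size

# Hypothetical constant inputs, sharded row-wise over the "x" mesh axis.
n_rows = 4 * n_dev  # divisible by the device count
data = jnp.arange(n_rows, dtype=jnp.float32)
log_ref_priors = jnp.zeros(n_rows, dtype=jnp.float32)
row_sharding = NamedSharding(mesh, P("x"))
sharded_data = jax.device_put(data, row_sharding)
sharded_log_ref_priors = jax.device_put(log_ref_priors, row_sharding)

def _local_loglike(mu, x, log_prior):
    # Runs on each device with its local block of rows.
    local = jnp.sum(-0.5 * (x - mu) ** 2 + log_prior)
    return jax.lax.psum(local, "x")  # add the partial sums across devices

sharded_loglike = shard_map(
    _local_loglike,
    mesh=mesh,
    in_specs=(P(), P("x"), P("x")),  # mu replicated, constants row-sharded
    out_specs=P(),
)

def likelihood_fn(mu):
    # The constants are closed over, so an outer vmap only sees (and batches) mu.
    return sharded_loglike(mu, sharded_data, sharded_log_ref_priors)

# Stand-in for the library's vmap over parameter samples.
mus = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
batched = jax.vmap(likelihood_fn)(mus)
print(batched.shape)  # (3,)
```

To try this with several CPU devices on one machine, set `XLA_FLAGS=--xla_force_host_platform_device_count=4` in the environment before importing `jax`. With a single device the sketch still runs, just without real sharding.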
Hi everyone,
I'm working on a likelihood function that uses `jax.experimental.shard_map` for memory-efficient computation across multiple CPU devices. This function is eventually called inside a library that applies `jax.vmap` over it, so I can't remove or control the `vmap`.

However, I'm hitting this error:
The traceback shows it originates from inside `shard_map` when trying to infer the shape. I'm guessing this means that under `vmap`, the global shape has an extra leading batch dimension that conflicts with the defined `PartitionSpec`.

Here's the minimal version of what I'm doing:
My question is: how can I make `likelihood_fn` compatible with an external `vmap`, while keeping `shard_map` working on the already sharded data?

Key constraints:
- `vmap` is imposed by a library I'm using; I can't remove or rewrite it.
- `sharded_data` and `sharded_log_ref_priors` are constant inputs.
- I need `shard_map` to parallelize across devices, and `vmap` to evaluate the function across parameter samples.

Would wrapping the `likelihood_fn` body with `jax.named_call`, or using `with_sharding_constraint`, help here? Or is there a better way to "freeze" the sharded arrays so the outer `vmap` doesn't try to batch them?

Thanks in advance for any insights; I'm really stuck here.
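On the "freeze" question, a small illustration of standard `jax.vmap` behaviour. It assumes only that the library calls `vmap` on a function of the parameters. Arrays bound before `vmap` is applied (here with `functools.partial`, but a closure works the same way) are treated as constants and never batched, which gives the same result as passing them through `vmap` with `in_axes=None`. `loglike` and the toy arrays are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loglike(mu, data, log_ref_priors):
    # Hypothetical per-sample log-likelihood; mu is a single parameter sample.
    return jnp.sum(-0.5 * (data - mu) ** 2 + log_ref_priors)

data = jnp.linspace(-1.0, 1.0, 8)
log_ref_priors = jnp.zeros(8)
mus = jnp.array([0.0, 0.5, 1.0])

# Option 1: bind the constants first, so vmap only ever receives mu.
frozen = partial(loglike, data=data, log_ref_priors=log_ref_priors)
out_partial = jax.vmap(frozen)(mus)

# Option 2: pass the constants through vmap but mark them as unbatched.
out_in_axes = jax.vmap(loglike, in_axes=(0, None, None))(mus, data, log_ref_priors)
```

Option 1 is the one that fits when the `vmap` call lives inside a library: the function handed to the library already has the constants baked in.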