Programming TPUs in JAX | How To Scale Your Model #12
jacobaustin123 started this conversation in General
Replies: 1 comment 2 replies
Hi! Thanks for the great book. I believe the solution attached to Problem 1 doesn't correspond to the problem description (it only covers one axis). I have also been running into an issue with my own solution and have written it up as a GitHub Issue. I would really appreciate it if you could help me out.
I have also hit another error that I couldn't solve. I run this line (as described in this chapter): `replicated = jax.lax.with_sharding_constraint(replicated, P(None, 'Y'))` and I keep getting an error. Has the API perhaps changed?
JAX programming, broadly construed!