-
I have recently picked up JAX and converted a project I'm working on from PyTorch. It's a bit of a learning curve, but I love functional programming, so it's fun. I discovered that Google Colab offers an 8-TPU session, and I want to understand how to optimise my sharding for the calculations I'm doing. Are there any good resources for reading about how sharding is commonly thought about?

From what I understand, I shard the inputs to a calculation and then let the compiler figure out how to handle the rest. For example, I can shard a batch into 8 parts, replicate the params across all devices, and perform the calculation in parallel. Could I also shard the parameters across different devices, and would that yield a speed-up? If I have 4 layers that run sequentially, I can't imagine how forcing the data to be passed between devices would be quicker. On the other hand, if I shard the param matrices, perform the calculation in parallel, and then combine the results, perhaps that would be quicker?

Any answers or direction would be appreciated :)
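To make the data-parallel case concrete, here is a minimal sketch of what I mean, using `jax.sharding.Mesh` and `NamedSharding`; the axis name `"data"`, the shapes, and the toy `forward` function are made up just for illustration:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy sizes, just for illustration
batch, d_in, d_out = 32, 128, 64

# 1-D mesh over all available devices (8 on a Colab TPU runtime)
mesh = Mesh(jax.devices(), axis_names=("data",))

# Shard the batch dimension across devices, replicate the params
x = jax.device_put(jnp.ones((batch, d_in)),
                   NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((d_in, d_out)),
                   NamedSharding(mesh, P()))  # fully replicated

@jax.jit
def forward(x, w):
    # The compiler sees the input shardings and runs each device's
    # slice of the batched matmul in parallel.
    return jnp.dot(x, w)

y = forward(x, w)
print(y.sharding)  # the output keeps the batch dimension sharded over "data"
```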
-
Hey!
I would recommend reading these docs:
For the techniques you mentioned, it really depends on what you are trying to do. Maybe https://jax-ml.github.io/scaling-book/training/ can help? This doc covers the techniques you were asking about: data parallelism, FSDP (fully-sharded data parallelism), and tensor parallelism (TP).
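To illustrate the tensor-parallel variant you asked about (sharding the parameter matrices themselves), here is a minimal sketch with made-up shapes; only the `PartitionSpec`s change compared to the data-parallel case, and the axis name `"model"` is an assumption for illustration:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

batch, d_in, d_out = 32, 128, 64  # toy sizes for illustration

mesh = Mesh(jax.devices(), axis_names=("model",))

# Tensor parallelism: replicate the activations, shard the weight's
# output dimension across devices. Each device computes a slice of the
# output features, so a column split needs no all-reduce afterwards.
x = jax.device_put(jnp.ones((batch, d_in)),
                   NamedSharding(mesh, P()))               # replicated
w = jax.device_put(jnp.ones((d_in, d_out)),
                   NamedSharding(mesh, P(None, "model")))  # column-sharded

@jax.jit
def forward(x, w):
    return x @ w

y = forward(x, w)
print(y.sharding)  # output features come out sharded along "model"
```

Whether this beats plain data parallelism depends on the sizes involved: sharding the params saves memory and can help when the weight matrices are large relative to the batch, but it trades replicated compute for inter-device communication, which is exactly the trade-off the scaling-book chapter walks through.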