Sharded Matrices and How to Multiply Them | How To Scale Your Model #5
Replies: 30 comments 51 replies
-
In the solution for pop quiz 2, the bidirectional ICI bandwidth for a TPU v5e is given as 9e10 bytes/s, which doesn't quite match the value of 1e11 bytes/s given in the table in part 2. Looking at https://cloud.google.com/tpu/docs/v5e, it appears that the value in the table is the correct one.
-
In the section "A quick aside: how would we describe this in code?", the text says "For instance, in the above example, the local shape of A is [4, 1024] and for B is [2048, 4096]". I think the local shape of A is [2, 1024]?
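For what it's worth, here is how I checked the local shape (a minimal sketch; the 4x2 mesh, the axis names, and the global shape f32[8, 2048] for A[I_X, J_Y] are my assumptions about the chapter's example):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed setup: 8 devices in a 4x2 grid, axes named 'X' and 'Y'.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('X', 'Y'))

# Assumed global A: f32[8, 2048], sharded A[I_X, J_Y].
A = jax.device_put(jnp.zeros((8, 2048)), NamedSharding(mesh, P('X', 'Y')))

# Per-device (local) shard shape: rows split 4 ways, columns split 2 ways.
print(A.addressable_shards[0].data.shape)  # (2, 1024)
```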
-
On the first picture, you state that the shape of matrix A before sharding is […]. For me, this would mean that A is sharded across its rows and B is sharded across its columns, so we have everything needed to calculate a single element of the result C, because the contracting dimensions are not sharded. But because you reversed the meaning of […], could you illustrate this with an image? I think I get the point, but it would help a lot to see visually what exactly you mean with these […].
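While waiting for a picture, jax.debug.visualize_array_sharding draws a rough ASCII version of exactly this; a small sketch with made-up shapes (the 4x2 mesh and axis names are my assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 4x2 mesh of 8 devices with axes named 'X' and 'Y'.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('X', 'Y'))

# A[I_X, J]: rows sharded over X.  B[J, K_Y]: columns sharded over Y.
# The contracting dimension J is unsharded on both operands.
A = jax.device_put(jnp.ones((128, 256)), NamedSharding(mesh, P('X', None)))
B = jax.device_put(jnp.ones((256, 512)), NamedSharding(mesh, P(None, 'Y')))

# Each device holds full rows of A and full columns of B, so it can compute
# its own block of C = A @ B (i.e. C[I_X, K_Y]) without any communication.
jax.debug.visualize_array_sharding(A)  # horizontal stripes (rows split over X)
jax.debug.visualize_array_sharding(B)  # vertical stripes (columns split over Y)
```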
-
Some issues with question 2:
- In part 2's solution, I think you mean for X to be in the denominator. The result is the same because X = Y in this case.
- In part 3's solution, you mention TPU v5e, but the question asks about v4p.
- In part 4, I'm not sure what AllGather with a {U_Z} dimension means. I believe this is not addressed in the text of the chapter. Also, the solution again mentions v5e.
-
The code below it and the […] in the code really say: 8 TPUs in a 4x2 grid.
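For reference, a minimal sketch of what the code's version describes (roughly; the axis names are my assumption):

```python
import jax
import numpy as np
from jax.sharding import Mesh

# 8 TPUs arranged as a 4x2 grid: axis 'X' has size 4, axis 'Y' has size 2.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=('X', 'Y'))
print(mesh.shape)  # axis sizes: X=4, Y=2
```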
-
I believe question 4 may have miscalculated the comms overhead for […]
-
The flow in this chapter is a little jarring when it drops into the four cases without defining the term "contracting dimensions" or doing other setup to smooth the transition. Maybe an external reference or a bit more connective flow would help?
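For other readers who trip over the same term, the working definition I pieced together: in

$$C[I, K] = A[I, J] \cdot B[J, K], \qquad C_{ik} = \sum_j A_{ij} B_{jk},$$

$J$ is the contracting dimension, the one that gets summed away and does not appear in the output, while $I$ and $K$ are the non-contracting dimensions. The four cases then turn on whether the contracting dimension is sharded on each multiplicand.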
-
In the solution to question 4, I believe it should be D < C / Wici instead of F < C / Wici when calculating when we are comms bound in strategy 1. The wording is also a bit confusing because it says "In the second case (baseline)", but it appears to be talking about strategy 1 if I'm not mistaken? Also a small grammatical error at the end of the solution - "we'll shard our parameters" instead of "we'll sharded our parameters".
-
The text says "For example, A[IX,J]⋅B[J,K]→C[IX,K] can be multiplied without any communication because the contracting dimension (J, the one we’re actually summing over) is unsharded. However, if we wanted the output unsharded (i.e. A[IX,J]⋅B[J,K]→C[IX,K]), we would need to copy A or C to every device." Presumably the last "C[IX,K]" should actually be "C[I,K]".
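As a small illustration of the distinction (my own sketch, not the book's code; the shapes and axis names are made up): asking the compiler for a sharded C[I_X, K] is free, while asking for a replicated C[I, K] forces it to insert a collective, typically an AllGather.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1D mesh of 4 devices named 'X'.
mesh = Mesh(np.array(jax.devices()[:4]), ('X',))

A = jax.device_put(jnp.ones((128, 256)), NamedSharding(mesh, P('X', None)))   # A[I_X, J]
B = jax.device_put(jnp.ones((256, 512)), NamedSharding(mesh, P(None, None)))  # B[J, K]

# C[I_X, K]: each device already owns its rows of C, so no communication.
matmul_sharded_out = jax.jit(lambda a, b: a @ b,
                             out_shardings=NamedSharding(mesh, P('X', None)))

# C[I, K] (replicated): every device needs every row, so the rows have to be
# copied (gathered) onto all devices before the result can be materialized.
matmul_replicated_out = jax.jit(lambda a, b: a @ b,
                                out_shardings=NamedSharding(mesh, P(None, None)))
```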
-
As someone who is fairly familiar with sharding and JAX, I think the flow of this chapter can be refined and the details (along with the notation) can be improved a lot. I am happy to contribute if you guys are open to contributions. I mean it when I say this is confusing and can be simplified.
-
Can you explain more about AllReduce? I think I misunderstand what it actually does in Question 2, Part 3. In my opinion, after we do […], because there is no communication between X and Y?
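For context, here is the mental model I'm working from (a minimal shard_map sketch; the mesh size and shapes are made up, and I may well be holding it wrong): every device contributes its own values, psum adds the corresponding elements across the axis, and every device ends up holding the same summed result. The chapter also describes an AllReduce as a ReduceScatter followed by an AllGather.

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Assumed 1D mesh of 4 devices named 'X'.
mesh = Mesh(np.array(jax.devices()[:4]), ('X',))

@partial(shard_map, mesh=mesh, in_specs=P('X'), out_specs=P())
def all_reduce_sum(x_local):
    # Each device sees its own 2-element shard; psum adds the shards
    # elementwise over axis 'X' and returns the identical total to every
    # device -- that is the AllReduce.
    return jax.lax.psum(x_local, axis_name='X')

x = jnp.arange(8.0)       # shards: [0,1], [2,3], [4,5], [6,7]
print(all_reduce_sum(x))  # [12. 16.] replicated on all 4 devices
```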
-
In Question 3, why does the answer say "Since we have an axis of size 4 on a TPU v4p, we have a wraparound link, so we can do the AllGather by sending half the bytes in each direction"? In the GIF above, I think each device sends all of its bytes in each direction. Is there any difference?
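For what it's worth, my current reading of the accounting (which may well be wrong): every device still contributes its whole shard, but with a wraparound link each shard can be split in half, with one half travelling clockwise and the other counterclockwise, so each direction carries only half the bytes and both directions run concurrently:

$$T_{\text{AllGather}} \approx \frac{N-1}{N}\cdot\frac{V}{W_{\text{bidirectional}}} \approx \frac{V}{W_{\text{bidirectional}}},$$

whereas a purely one-directional ring would leave half of each link idle and take roughly twice as long.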
-
Thanks for the great work! I have a question about the bi-directional all-gather case: since each hop sends […]
-
This is a fantastic book! Kudos to the authors and a big THANK YOU! I think this section is critical for appreciating how TPUs differ from GPUs, but it needs quite substantial rework: […]
I hope my feedback is not misconstrued. I feel this book overall is phenomenal in its objectives and style, and definitely stands out in the crowd of similar efforts. Thank you again!
-
In question 4, I believe some of the math and reasoning for All-Gather being the preferred strategy is incorrect. At the beginning, […]
So for reasonably common batch sizes, we're ICI-bound for strategy 1, just as we are for strategy 2. In that case, we need to compare the ICI times for both strategies to decide which one is best. Strategy 2 is best when: […]
So basically, for reasonable batch sizes (~1-2K) and D (~4K), strategy 2 is better than strategy 1. I also built a bunch of plots in this Colab, which showed that for certain large values of D and F it's never even beneficial to do strategy 1 (for example, when D=8K, F=16K), while for other values (D=4K, F=16K) it's better to do strategy 2 for B < 2K and then slightly better to do strategy 1 for larger values of B. Unless I screwed up the math above, I believe the recommendation that the "All-Gather" strategy is better for Case 2 should be reconsidered. At smaller batches, the "All-Reduce" strategy seems to be much better. It also makes sense when reasoning about it at a high level: when you have a giant weight matrix (i.e. […]). -- A small nit re. the same question: it never mentions that we want to do everything in bfloat16; it would be great to add that info. -- Thank you for reading, and also thank you for providing such a great learning resource for the community!
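P.S. For concreteness, here is how I reconstructed the comparison (my assumptions: everything in bf16, both strategies comms-bound, an AllGather of an array of $V$ bytes costing roughly $V / W_{\text{ici}}$, and an AllReduce costing roughly twice that for the same array):

$$T_{\text{strategy 1}} \approx \frac{2DF}{W_{\text{ici}}} \;\;\text{(AllGather the weights)}, \qquad T_{\text{strategy 2}} \approx \frac{2 \cdot 2BF}{W_{\text{ici}}} \;\;\text{(AllReduce the partial outputs)},$$

so strategy 2 wins whenever $4BF < 2DF$, i.e. $B < D/2$, which matches the $B < 2\text{K}$ crossover for $D = 4\text{K}$ in my plots.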
-
Thanks for this book -- I've learnt a lot! I have a question about the calculation for the time it takes to do an AllGather, where the conclusion was that the time does not depend on […]
Of course, if […]. I also like this other way of reasoning about how much time it should take: each of the […]
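In case it's useful to anyone else, the accounting as I understand it (assuming a 1D ring of $N$ devices, a fully gathered array of $V$ bytes, and bidirectional ICI bandwidth $W_{\text{ici}}$ per device): each shard is $V/N$ bytes and has to make $N-1$ hops, but the hops of different shards overlap in time, so each device just pays for the $N-1$ chunks arriving over its own links:

$$T_{\text{AllGather}} \approx (N-1)\cdot\frac{V/N}{W_{\text{ici}}} = \frac{N-1}{N}\cdot\frac{V}{W_{\text{ici}}} \approx \frac{V}{W_{\text{ici}}},$$

which is why the dependence on the number of devices cancels out.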
-
Thanks for the visual clarity fix, Jacob! Much appreciated.
-
In "A Deeper Dive into TPU Communication Primitives" I would add some intuition behind the "mechanics" of the matrix/shard juggling: the whys.
-
This article is very useful for me, coming from the GPU world without any prior TPU background! I think it's worth pointing out that the […]
-
Hi Jacob, new poster here, thank you for this blog post. This might be a stupid question, but it seems to me that online resources (yours included) assume that when we shard our data across devices, the shape of the data is always (Batch Size, Length), i.e. that the embedding of the data is done after it has been distributed across the devices. I was wondering: in that case, how could we shard a batch of pre-embedded multidimensional data, i.e. (Batch Size, d1, d2, d3)?

I'm currently working on training an equivariant neural network that ingests crystal data to learn a latent space; however, the runtime cost is quite heavy, so I am trying to distribute the training across a GPU cluster. My data is represented as a 3D incidence matrix where the channels are embedding representations of nodes and edges, so the dimensions would be (Node Len, Edge Len, Embedding Dim). When I batch this data, I get (Batch, Node, Edge, Embed). Is the best practice to just embed the data during the training iteration, or is there a workaround to distribute these multidimensional batches? I'm relatively new to JAX and distributed training as a whole, so there don't seem to be a lot of resources around this.
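For what it's worth, here is the kind of thing I have been experimenting with (the mesh, names, and shapes are mine, and I am not sure it is best practice): shard only the leading batch axis of the already-embedded 4D array and leave the other axes alone.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1D data-parallel mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), ('data',))

# Pre-embedded batch: (Batch, Node, Edge, Embed). For this simple layout the
# batch size must be divisible by the number of devices.
batch = jnp.zeros((32, 64, 64, 128))

# Shard only the batch dimension; Node/Edge/Embed stay unsharded.
batch = jax.device_put(batch, NamedSharding(mesh, P('data', None, None, None)))

print(batch.sharding)                          # shows the mesh and PartitionSpec
print(batch.addressable_shards[0].data.shape)  # (32 // n_devices, 64, 64, 128)
```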
-
I'm thinking about the equivalent of a TPU axis in a GPU server. For a fully connected GPU clique of N devices (say 8 or 16 GPUs connected by NVLink/NVSwitch), is the bandwidth basically N * the unidirectional link speed? Since a TPU is connected to at most 6 neighbours (max speed 6 * the unidirectional W_ICI), it seems to me that for communication operations TPUs would be much slower? (Ignoring cost.)
-
Hi Jacob, thanks for the great knowledge sharing! Regarding the […]
-
In question 10.1, why is the number of floats communicated by a ReduceScatter the same as that of an AllGather? Doesn't ReduceScatter need to communicate less, since the partial sums remain scattered and don't need to be gathered?
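For context, my own (possibly wrong) understanding of the ring accounting: the two collectives are mirror images of each other. In both, every device passes a shard-sized chunk of $V/N$ bytes (with $V$ the full array size and $N$ the axis size) to its neighbour for $N-1$ steps; a ReduceScatter just adds each incoming chunk into a running partial sum instead of appending it, so the bytes on the wire are identical:

$$T_{\text{ReduceScatter}} \approx T_{\text{AllGather}} \approx \frac{N-1}{N}\cdot\frac{V}{W_{\text{ici}}}.$$

The saving I was expecting seems to show up only relative to a full AllReduce (a ReduceScatter followed by an AllGather, roughly twice the cost), not relative to a single AllGather.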
-
In Pop Quiz 2 Part 1, I wonder if we should use the unidirectional bandwidth (which is 4.5e10) because the Y axis size is smaller than 16. IIUC, the answer should then be Tcomms = 34e6 / 4.5e10 ≈ 756 μs. I'm curious if I'm missing something.
-
Hi, thank you for the great explanation. I have a question regarding 10.2. Why is the data size considered to be […]?
-
Could you clarify the […]? I also don't have a great intuition for how an […]
-
In Case 3 (both multiplicands have sharded contracting dimensions), is the ReduceScatter done in bf16 or f32? Typically matmul accumulation needs to be done in f32; does that mean the communication cost for the ReduceScatter will be higher?
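For context, the arrangement I have in mind (a rough shard_map sketch with made-up shapes and axis names; I don't know whether this is what the book assumes): accumulate the local partial matmul in f32 via preferred_element_type, then either keep f32 for the ReduceScatter (roughly twice the communication bytes) or cast back to bf16 first (cheaper comms, some precision loss at the reduction step).

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Assumed 1D mesh of 4 devices named 'X'.
mesh = Mesh(np.array(jax.devices()[:4]), ('X',))

# Case 3: A[I, J_X] @ B[J_X, K] -> C[I, K_X], contracting dimension J sharded.
@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'X'), P('X', None)), out_specs=P(None, 'X'))
def matmul_case3(a_local, b_local):
    # Local bf16 matmul, accumulated in f32.
    partial_c = jnp.dot(a_local, b_local, preferred_element_type=jnp.float32)
    # Option A: ReduceScatter in f32 (2x the bytes of bf16).
    # Option B (shown): cast to bf16 before the ReduceScatter to halve comms.
    partial_c = partial_c.astype(jnp.bfloat16)
    return jax.lax.psum_scatter(partial_c, 'X', scatter_dimension=1, tiled=True)

A = jnp.ones((128, 512), jnp.bfloat16)   # A[I, J]
B = jnp.ones((512, 256), jnp.bfloat16)   # B[J, K]
C = matmul_case3(A, B)                   # C[I, K], sharded over K
```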
-
For question 7, I believe I might be misunderstanding the notation. We want to multiply matrices C and B, and then multiply the result by matrix x, correct? In that case, it appears the shapes are incompatible: the result of C * B is [F, F], which is incompatible with the shape of x, [B, D].
-
In the first pop quiz, you write that […]
but this is wrong because 128 * 2048 * 2 = 524,288 bytes ≈ 524 kB.
-
This isn't working in the Colab:
Should be changed to this:
-
Sharded matrix multiplications galore!