
Commit 05d3f7c

[llama4] add back max_len arg to generate_permute_indices (#1354)
Fixes #1353. The bug is related to this revert, which changed the arg list: #1340. So this arg needs to be added back: 4e31b86. cc @lessw2020 @tianyu-l
1 parent 0c368fe commit 05d3f7c


torchtitan/experiments/llama4/model/moe.py

Lines changed: 1 addition & 0 deletions
@@ -295,6 +295,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
             self.experts.num_experts,
             1,
+            token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
             ALIGN_SIZE_M,
         )
         token_indices = torch.vstack(
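
For context on why the restored max_len argument is sized as token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, here is a minimal standalone sketch. It assumes (based on the alignment argument) that generate_permute_indices pads each expert's token block up to a multiple of ALIGN_SIZE_M; all names and values below are illustrative and not taken from torchtitan.

import torch

# Illustrative only: hypothetical values, not torchtitan code.
ALIGN_SIZE_M = 16                      # alignment assumed to be required by the grouped GEMM
num_experts = 4
num_local_tokens_per_expert = torch.tensor([5, 0, 33, 2])

token_count = int(num_local_tokens_per_expert.sum())   # stands in for token_indices.shape[0]
max_len = token_count + num_experts * ALIGN_SIZE_M      # the restored argument

# If each expert's block is rounded up to a multiple of ALIGN_SIZE_M, the
# padded total can never exceed max_len, since each expert adds at most
# ALIGN_SIZE_M - 1 padding slots.
padded_total = int((torch.ceil(num_local_tokens_per_expert / ALIGN_SIZE_M) * ALIGN_SIZE_M).sum())
assert padded_total <= max_len
print(f"tokens={token_count}, padded={padded_total}, max_len={max_len}")

As the hunk shows, the new argument sits between 1 and ALIGN_SIZE_M in the call to generate_permute_indices; only that line was added by this commit.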

0 commit comments