
Commit 760c1f4

jc-audet authored and tianyu-l committed
Add check for seq_len%tensor_parallel_degree==0 for parallelized Llama (#1312)
Mitigates #1306. Following the discussion in #1306, `seq_len % tensor_parallel_degree == 0` appears to be a necessary condition for the TP Llama3 model to work, since the current setup is a workaround for [this](pytorch/pytorch#130646) numerical issue with PyTorch DTensors of complex numbers. This PR makes the requirement explicit. --------- Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
1 parent 00a6cf3 commit 760c1f4
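
For reference, a minimal standalone sketch of the divisibility constraint this commit enforces; the helper name `check_seq_len_divisibility` and the example values are illustrative only and are not part of the torchtitan code (the real check, shown in the diff below, reads `job_config.training.seq_len`, `parallel_dims.tp`, and `parallel_dims.cp`):

# Illustrative sketch of the constraint added in this commit; names and
# example values are hypothetical, not taken from torchtitan itself.
def check_seq_len_divisibility(seq_len: int, tp_degree: int, cp_degree: int) -> None:
    # seq_len must be evenly divisible by tp_degree * cp_degree
    if seq_len % (tp_degree * cp_degree) != 0:
        raise ValueError(
            f"Sequence length {seq_len} must be divisible by the product of "
            f"TP degree ({tp_degree}) and CP degree ({cp_degree})."
        )

check_seq_len_divisibility(8192, 8, 2)    # OK: 8192 % (8 * 2) == 0
# check_seq_len_divisibility(8190, 8, 2)  # would raise: 8190 % 16 != 0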

File tree

1 file changed: +11 −0 lines changed


torchtitan/models/llama3/infra/parallelize.py

Lines changed: 11 additions & 0 deletions
@@ -45,6 +45,17 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
+    # TODO: TP currently cannot handle uneven seq_len because we set `use_local_output=True`
+    # (to use plain Tensors), which was because of the bug in computation of complex
+    # numbers with DTensors when setting `use_local_output=False`.
+    # See https://github.com/pytorch/pytorch/issues/130646 and
+    # https://github.com/pytorch/torchtitan/issues/1306 for details.
+    assert (
+        job_config.training.seq_len % (parallel_dims.tp * parallel_dims.cp) == 0
+    ), f"""
+    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
+    ({parallel_dims.tp}) and CP degree ({parallel_dims.cp}).
+    """

     if parallel_dims.tp_enabled:
         if (
