
Commit 0ec2b2f

Fixes to align with latest autoparallel
1 parent c8fb6b5 commit 0ec2b2f

File tree

2 files changed: 9 additions & 6 deletions

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def input_fn():
 # model = model_fn()
 # return model

-autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type)
+autop = AutoParallel(model, input_fn, world_mesh)
 autop.add_parameter_memory_constraint(low=None, high=None)

 x_sharding = (Shard(0), Replicate())
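
For context, the only change in this file is dropping the explicit device argument; the latest autoparallel presumably derives the device type from the DeviceMesh it is given. Below is a minimal sketch of the updated call site, assuming the model, input_fn, and world_mesh objects built earlier in parallelize_llama.py; the import path is a guess, and this mirrors the diff above rather than documenting autoparallel's API.

# Hedged sketch: model, input_fn, and world_mesh come from the surrounding
# parallelize_llama.py and are not defined here; the import path is a guess.
from autoparallel.api import AutoParallel

# Old: AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type)
# New: no device kwarg; the mesh carries the device type.
autop = AutoParallel(model, input_fn, world_mesh)
autop.add_parameter_memory_constraint(low=None, high=None)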

torchtitan/train.py

Lines changed: 8 additions & 5 deletions
@@ -23,7 +23,7 @@
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-
+from torch.distributed.tensor import DTensor
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
@@ -158,10 +158,13 @@ def param(name):
 from torchtitan.models.llama3.model import precompute_freqs_cis

 model.buffers_.get_buffer("freqs_cis").copy_(
-    precompute_freqs_cis(
-        model_args.dim // model_args.n_heads,
-        model_args.max_seq_len,
-        model_args.rope_theta,
+    DTensor.from_local(
+        precompute_freqs_cis(
+            model_args.dim // model_args.n_heads,
+            model_args.max_seq_len,
+            model_args.rope_theta,
+        ),
+        device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
     )
 )
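
Presumably the wrapper is needed because, under AutoParallel, the freqs_cis buffer is itself a DTensor, and DTensor ops generally reject mixing in plain torch.Tensor operands; the freshly recomputed values are therefore wrapped with DTensor.from_local on the buffer's own device_mesh (with no placements given, from_local treats the local tensor as replicated) before the in-place copy. A small self-contained sketch of the same pattern, using a stand-in buffer rather than torchtitan's real freqs_cis and assuming a process group launched via torchrun:

# Hedged sketch of copying locally computed values into a DTensor buffer.
# The buffer and values here are stand-ins, not torchtitan's freqs_cis.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

dist.init_process_group()  # e.g. torchrun --nproc-per-node=2 this_script.py
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# A replicated DTensor buffer, standing in for model.buffers_.get_buffer("freqs_cis").
buf = distribute_tensor(torch.zeros(4, 8), mesh, [Replicate()])

# Deterministic stand-in for precompute_freqs_cis(...): identical on every rank.
new_values = torch.arange(32, dtype=torch.float32).reshape(4, 8)

# copy_ with a plain Tensor would mix Tensor and DTensor, so wrap it first;
# with no placements given, from_local assumes the tensor is replicated.
buf.copy_(DTensor.from_local(new_values, device_mesh=buf.device_mesh))

dist.destroy_process_group()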
