From dc195cc5d0ea6b7813f48d4bd07fc5122ca88cfb Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 12 Jun 2025 06:45:36 -0700 Subject: [PATCH 1/5] WIP: Try to use monarch to run torchtitan. Repro: GROUP_WORLD_SIZE=1 WORLD_SIZE=2 python -m torchtitan.train_monarch \ --job.config_file \ ./torchtitan/models/llama3/train_configs/llama3_8b.toml Currently there is an error: KeyError: "Invalid mesh_dim_names ('dp_cp',) specified. Valid mesh_dim_names are ['dp_shard']." --- torchtitan/distributed/parallel_dims.py | 2 + torchtitan/train.py | 7 +++ torchtitan/train_monarch.py | 66 +++++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 torchtitan/train_monarch.py diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 5f8bc5025..82aae2032 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -58,6 +58,7 @@ def build_mesh(self, device_type: str) -> DeviceMesh: ["pp", "dp_replicate", "dp_shard", "cp", "tp"], ): if d > 1: + # print(f"AHMAD HERE: {d} {name}") dims.append(d) names.append(name) @@ -100,6 +101,7 @@ def _build_mesh( mesh_dim_name="dp_shard_cp" ) if dp_cp_mesh_dim_names != []: + print(f" ==== *** AHMAD: dp_cp_mesh_dim_names {dp_cp_mesh_dim_names}") mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh diff --git a/torchtitan/train.py b/torchtitan/train.py index 9340671d7..c43c3b0c4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -92,6 +92,8 @@ def __init__(self, job_config: JobConfig): world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) + print(f"AHMAD: Parallelism: {parallel_dims} {world_size=}") + logger.error(f" AHMAD: Parallelism: {parallel_dims} {world_size=}") dist_utils.init_distributed(job_config) # build meshes @@ -442,6 +444,8 @@ def train_step( if not self.metrics_processor.should_log(self.step): return + print(f" AHMAD: {parallel_dims.dp_replicate_enabled=} {parallel_dims.dp_shard_enabled=} {parallel_dims.cp_enabled=} {self.ft_manager.enabled=}") + print(f" AHMAD: {self.world_mesh=}") if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -455,6 +459,8 @@ def train_step( and self.job_config.fault_tolerance.semi_sync_method is None ) ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None + print(f" ==== > AHMAD {self.world_mesh=} {ft_pg=}") + # print(self.world_mesh["dp_cp"]) global_avg_loss, global_max_loss = ( dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg), @@ -532,6 +538,7 @@ def close(self) -> None: if __name__ == "__main__": + print(os.environ) init_logger() config_manager = ConfigManager() config = config_manager.parse_args() diff --git a/torchtitan/train_monarch.py b/torchtitan/train_monarch.py new file mode 100644 index 000000000..c3a84e757 --- /dev/null +++ b/torchtitan/train_monarch.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import asyncio +import importlib +import os +import pickle +import sys +import time +from datetime import timedelta +from logging import getLogger +from typing import Any, Generator, Iterable, Optional +import torch +import torchtitan.components.ft as ft +import torchtitan.protocols.train_spec as train_spec_module +from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh +from monarch.actor_mesh import Actor, current_rank, endpoint +from monarch.proc_mesh import proc_mesh, ProcMesh +from monarch_meta._monarch_meta import hyperactor_meta +from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.tools.logging import init_logger, logger +from .train import Trainer + + +class TrainerActorWrapper(Actor): + def __init__(self, job_config: JobConfig): + self.job_config = job_config + self.rank = current_rank().rank + os.environ["RANK"] = str(self.rank) + os.environ["ROLE_RANK"] = str(self.rank) + os.environ["LOCAL_RANK"] = str(self.rank % 8) + world_size = int(os.environ["WORLD_SIZE"]) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size % 8) + self.trainer = Trainer(self.job_config) + + @endpoint + def train(self): + self.trainer.train() + print("hello world") + +async def async_main(job_config: JobConfig): + torch.use_deterministic_algorithms(True) + world_size = int(os.environ["WORLD_SIZE"]) + # world_size = 2 + local_proc_mesh = await proc_mesh( + gpus=world_size, + env={ + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12356", + }, + ) + print(job_config) + trainer_actor = await local_proc_mesh.spawn( + "trainer_actor", TrainerActorWrapper, job_config + ) + await trainer_actor.train.call() + + +if __name__ == "__main__": + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + asyncio.run(async_main(config)) + sys.exit(0) From b4ee03881093c045caf24fff078eb2ec5226edd9 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 12 Jun 2025 12:50:42 -0700 Subject: [PATCH 2/5] Seems to be working for 8 GPUs --- torchtitan/train.py | 10 +++++ torchtitan/train_monarch.py | 73 ++++++++++++++++++++++++++++++++----- 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index c43c3b0c4..1617e37e1 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -32,6 +32,8 @@ maybe_enable_profiling, ) +logger.info = logger.warn + class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig @@ -92,6 +94,8 @@ def __init__(self, job_config: JobConfig): world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) + print(os.environ) + logger.error(os.environ) print(f"AHMAD: Parallelism: {parallel_dims} {world_size=}") logger.error(f" AHMAD: Parallelism: {parallel_dims} {world_size=}") dist_utils.init_distributed(job_config) @@ -99,15 +103,20 @@ def __init__(self, job_config: JobConfig): # build meshes self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) if parallel_dims.dp_enabled: + print("AHMAD dp_enabled") dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: + print("AHMAD dp_disabled") dp_degree, dp_rank = 1, 0 + print(self.world_mesh["dp_cp"]) + self.ft_manager = ft.init_ft_manager(job_config) # If TorchFT is enabled, the dp_rank and dp_degree, which are used for # dataloader must be changed. 
if self.ft_manager.enabled: + print("AHMAD: ft_manager enabled") dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) # take control of garbage collection to avoid stragglers @@ -325,6 +334,7 @@ def __init__(self, job_config: JobConfig): f"total steps {job_config.training.steps} " f"(warmup {job_config.lr_scheduler.warmup_steps})." ) + print(self.world_mesh["dp_cp"]) def batch_generator( self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] diff --git a/torchtitan/train_monarch.py b/torchtitan/train_monarch.py index c3a84e757..a8f87c5ab 100644 --- a/torchtitan/train_monarch.py +++ b/torchtitan/train_monarch.py @@ -7,6 +7,7 @@ import importlib import os import pickle +import threading import sys import time from datetime import timedelta @@ -23,32 +24,84 @@ from torchtitan.tools.logging import init_logger, logger from .train import Trainer +def pretend_you_are_torchrun(global_rank): + """ + Eventually, Monarch should handle all of this, but it's necessary for now because the job is + not running torchrun. Also there are already better ways to avoid hardcoding this, but + it's a demo and we'll live for now. + """ + # task_id = int(os.environ["TW_TASK_ID"]) + # global_rank = task_id * 8 + (global_rank % 8) + world_size = int(os.environ["WORLD_SIZE"]) + lr = min(world_size, global_rank % 8) + local_world_size = min(world_size, 8) + env = { + # "MASTER_ADDR": get_master_addr(), + # "MASTER_PORT": str(20101), + "RANK": str(global_rank), + "LOCAL_RANK": str(lr), + "LOCAL_WORLD_SIZE": str(local_world_size), + + "GROUP_RANK": str(0), + "GROUP_WORLD_SIZE": str(1), + + "ROLE_RANK": str(global_rank), + "ROLE_WORLD_SIZE": str(world_size), + "ROLE_NAME": "rank", + + # Note that WORLD_SIZE is already set. + } + os.environ.update(env) + class TrainerActorWrapper(Actor): def __init__(self, job_config: JobConfig): self.job_config = job_config self.rank = current_rank().rank - os.environ["RANK"] = str(self.rank) - os.environ["ROLE_RANK"] = str(self.rank) - os.environ["LOCAL_RANK"] = str(self.rank % 8) - world_size = int(os.environ["WORLD_SIZE"]) - os.environ["LOCAL_WORLD_SIZE"] = str(world_size % 8) - self.trainer = Trainer(self.job_config) + pretend_you_are_torchrun(self.rank) @endpoint def train(self): - self.trainer.train() - print("hello world") + print("Starting training") + pretend_you_are_torchrun(self.rank) + config = self.job_config + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + # trainer = self.trainer + tid = threading.get_native_id() + logger.error(f"AHMAD tid in train: {tid=}") + trainer.train() + + if config.checkpoint.create_seed_checkpoint: + assert ( + int(os.environ["WORLD_SIZE"]) == 1 + ), "Must create seed checkpoint using a single device, to disable sharding." + assert ( + config.checkpoint.enable_checkpoint + ), "Must enable checkpointing when creating a seed checkpoint." 
+ trainer.checkpointer.save(curr_step=0, force=True) + logger.info("Created seed checkpoint") + else: + trainer.train() + finally: + if trainer: + trainer.close() + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info("Process group destroyed.") + print("Done training") async def async_main(job_config: JobConfig): torch.use_deterministic_algorithms(True) world_size = int(os.environ["WORLD_SIZE"]) - # world_size = 2 local_proc_mesh = await proc_mesh( gpus=world_size, env={ "MASTER_ADDR": "localhost", - "MASTER_PORT": "12356", + "MASTER_PORT": "12358", }, ) print(job_config) From 8a8d5ea11a82f40b8fa45a0643275767967f1ff7 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 12 Jun 2025 13:00:59 -0700 Subject: [PATCH 3/5] . --- torchtitan/train.py | 17 ----------------- torchtitan/train_monarch.py | 2 +- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 1617e37e1..9340671d7 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -32,8 +32,6 @@ maybe_enable_profiling, ) -logger.info = logger.warn - class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig @@ -94,29 +92,20 @@ def __init__(self, job_config: JobConfig): world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) - print(os.environ) - logger.error(os.environ) - print(f"AHMAD: Parallelism: {parallel_dims} {world_size=}") - logger.error(f" AHMAD: Parallelism: {parallel_dims} {world_size=}") dist_utils.init_distributed(job_config) # build meshes self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) if parallel_dims.dp_enabled: - print("AHMAD dp_enabled") dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: - print("AHMAD dp_disabled") dp_degree, dp_rank = 1, 0 - print(self.world_mesh["dp_cp"]) - self.ft_manager = ft.init_ft_manager(job_config) # If TorchFT is enabled, the dp_rank and dp_degree, which are used for # dataloader must be changed. if self.ft_manager.enabled: - print("AHMAD: ft_manager enabled") dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) # take control of garbage collection to avoid stragglers @@ -334,7 +323,6 @@ def __init__(self, job_config: JobConfig): f"total steps {job_config.training.steps} " f"(warmup {job_config.lr_scheduler.warmup_steps})." 
) - print(self.world_mesh["dp_cp"]) def batch_generator( self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] @@ -454,8 +442,6 @@ def train_step( if not self.metrics_processor.should_log(self.step): return - print(f" AHMAD: {parallel_dims.dp_replicate_enabled=} {parallel_dims.dp_shard_enabled=} {parallel_dims.cp_enabled=} {self.ft_manager.enabled=}") - print(f" AHMAD: {self.world_mesh=}") if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -469,8 +455,6 @@ def train_step( and self.job_config.fault_tolerance.semi_sync_method is None ) ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None - print(f" ==== > AHMAD {self.world_mesh=} {ft_pg=}") - # print(self.world_mesh["dp_cp"]) global_avg_loss, global_max_loss = ( dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg), @@ -548,7 +532,6 @@ def close(self) -> None: if __name__ == "__main__": - print(os.environ) init_logger() config_manager = ConfigManager() config = config_manager.parse_args() diff --git a/torchtitan/train_monarch.py b/torchtitan/train_monarch.py index a8f87c5ab..65897099f 100644 --- a/torchtitan/train_monarch.py +++ b/torchtitan/train_monarch.py @@ -101,7 +101,7 @@ async def async_main(job_config: JobConfig): gpus=world_size, env={ "MASTER_ADDR": "localhost", - "MASTER_PORT": "12358", + "MASTER_PORT": "12359", }, ) print(job_config) From daf506aab4c103d2ff1b87f2f775d15a1a71204a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 12 Jun 2025 13:01:42 -0700 Subject: [PATCH 4/5] . --- torchtitan/distributed/parallel_dims.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 82aae2032..5f8bc5025 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -58,7 +58,6 @@ def build_mesh(self, device_type: str) -> DeviceMesh: ["pp", "dp_replicate", "dp_shard", "cp", "tp"], ): if d > 1: - # print(f"AHMAD HERE: {d} {name}") dims.append(d) names.append(name) @@ -101,7 +100,6 @@ def _build_mesh( mesh_dim_name="dp_shard_cp" ) if dp_cp_mesh_dim_names != []: - print(f" ==== *** AHMAD: dp_cp_mesh_dim_names {dp_cp_mesh_dim_names}") mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh From c5668252f8dd12bcdff7b2e567bf9177f0078cd1 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 12 Jun 2025 13:18:51 -0700 Subject: [PATCH 5/5] Repro bug --- torchtitan/train_monarch.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchtitan/train_monarch.py b/torchtitan/train_monarch.py index 65897099f..997218a65 100644 --- a/torchtitan/train_monarch.py +++ b/torchtitan/train_monarch.py @@ -59,6 +59,7 @@ def __init__(self, job_config: JobConfig): self.job_config = job_config self.rank = current_rank().rank pretend_you_are_torchrun(self.rank) + self.trainer = Trainer(config) @endpoint def train(self): @@ -66,9 +67,15 @@ def train(self): pretend_you_are_torchrun(self.rank) config = self.job_config trainer: Optional[Trainer] = None + repro_bug = True try: - trainer = Trainer(config) + # This works fine if we run trainer = Trainer(config) here + # and comment out the one in __init__() above. + # However, with this change, you should get an error like this: + # KeyError: "Invalid mesh_dim_names ('dp_cp',) specified. + trainer = self.trainer + # trainer = self.trainer tid = threading.get_native_id() logger.error(f"AHMAD tid in train: {tid=}")
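
For context on the KeyError quoted in patch 1/5 and patch 5/5: a DeviceMesh only resolves a flattened dim name like "dp_cp" after ParallelDims._build_mesh() has registered it via mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp"); train_step() then slices self.world_mesh["dp_cp"]. Below is a minimal single-process sketch of that mechanism, for illustration only: the gloo backend, CPU device, (1, 1) mesh shape, and port 12360 are arbitrary choices made here, and it assumes a PyTorch build where the private DeviceMesh._flatten behaves the way parallel_dims.py uses it.

import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Single-process setup so the sketch runs on CPU without torchrun.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12360")  # arbitrary free port
dist.init_process_group("gloo", rank=0, world_size=1)

# A mesh that knows about "dp_shard" and "cp" but not the flattened "dp_cp".
mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("dp_shard", "cp"))

try:
    mesh["dp_cp"]  # the same slice train_step() takes via self.world_mesh["dp_cp"]
except KeyError as e:
    print(e)  # Invalid mesh_dim_names ('dp_cp',) specified. Valid mesh_dim_names are ...

# What ParallelDims._build_mesh() does before train_step() ever runs:
mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp")
print(mesh["dp_cp"])  # now resolves

dist.destroy_process_group()

Patch 5/5 shows the same lookup failing once the Trainer (and therefore its world mesh) is built in the actor's __init__ rather than inside the train() endpoint, as the inline comment in that hunk describes.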