diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 33d4dc17f..aade405b4 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -39,6 +39,12 @@ def estimate_memory(job_config: JobConfig): job_config.training.compile = False job_config.parallelism.enable_compiled_autograd = False + # init fake pg + store = FakeStore() + torch.distributed.init_process_group( + "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store + ) + parallelism_config = job_config.parallelism parallel_dims = ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, @@ -48,8 +54,9 @@ def estimate_memory(job_config: JobConfig): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) + # ParallelDims.build_mesh has to happen outside of the FakeTensorMode + _ = parallel_dims.world_mesh # only FSDP and HSDP are supported if ( @@ -68,28 +75,21 @@ def estimate_memory(job_config: JobConfig): device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) - # init fake pg - store = FakeStore() - torch.distributed.init_process_group( - "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store - ) - train_spec = get_train_spec(job_config.model.name) - # build meshes - world_mesh = parallel_dims.build_mesh(device_type="cuda") - # build tokenizer tokenizer = train_spec.build_tokenizer_fn(job_config) + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) train_context = dist_utils.get_train_context( - parallel_dims.loss_parallel_enabled, + loss_parallel_enabled, job_config.parallelism.enable_compiled_autograd, ) # build model (using meta init) - model_cls = train_spec.cls - model_args = train_spec.config[job_config.model.flavor] + model_args = train_spec.model_args[job_config.model.flavor] model_args.update_from_config(job_config, tokenizer) with ( @@ -101,14 +101,14 @@ def estimate_memory(job_config: JobConfig): f"Building {train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): - model = model_cls(model_args) + model = train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) # apply PT-D DP/TP parallelisms and activation checkpointing - train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device="cuda") if not active_fake_mode(): @@ -117,9 +117,7 @@ def estimate_memory(job_config: JobConfig): # build optimizer after applying parallelisms to the model ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers( - [model], job_config, parallel_dims, world_mesh, ft_manager - ) + optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. 
calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index ef31c1850..07966c276 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -106,14 +106,13 @@ def test_generate( # Tokenizer setup tokenizer = train_spec.build_tokenizer_fn(config) - model_cls = train_spec.cls - model_args = train_spec.config[config.model.flavor] + model_args = train_spec.model_args[config.model.flavor] model_args.update_from_config(config, tokenizer) init_device = "meta" if world_size > 1 else device with torch.device(init_device): logger.info(f"Init model on init_device: {init_device}") - model = model_cls(model_args) + model = train_spec.model_cls(model_args) world_mesh = None # Init distributed env @@ -127,14 +126,12 @@ def test_generate( pp=1, ep=1, world_size=world_size, - enable_loss_parallel=False, ) - # Build world mesh for parallelism - world_mesh = parallel_dims.build_mesh(device_type=device_type) + world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - apply_tp_minus_sp(model, world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) dist_utils.set_determinism(world_mesh, device, seed, deterministic) diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index 704e81a91..6b9d9515f 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -23,7 +23,6 @@ def build_parallel_dims(job_config, world_size): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) return parallel_dims diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index c364af385..5b0145477 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -9,12 +9,14 @@ import pytest import torch import torch.nn as nn +from torchtitan.components.ft import FTManager from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers, OptimizersContainer from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.config_manager import JobConfig from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.llama3 import parallelize_llama, pipeline_llama from torchtitan.protocols.train_spec import ( apply_to_train_specs, @@ -39,7 +41,10 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: def fake_build_optimizers( - model_parts: list[nn.Module], job_config: JobConfig + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager, ) -> OptimizersContainer: optimizer_kwargs = { "lr": 0.1, @@ -57,11 +62,11 @@ def fake_build_optimizers( class TestTrainSpec: def test_register_train_spec(self): - fake_config = {"fake": None} + fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( name="fake", - cls=FakeModel, - config=fake_config, + model_cls=FakeModel, + model_args=fake_config, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, @@ -78,11 +83,11 @@ def 
test_register_train_spec(self): new_spec = get_train_spec("fake2") def test_optim_hook(self): - fake_config = {"fake": None} + fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( name="fake2", - cls=FakeModel, - config=fake_config, + model_cls=FakeModel, + model_args=fake_config, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=fake_build_optimizers, @@ -111,21 +116,27 @@ def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec: original_build_optimizers_fn = spec.build_optimizers_fn def my_build_optimizer_fn( - model_parts: list[nn.Module], job_config: JobConfig + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager, ) -> OptimizersContainer: - optimizers = original_build_optimizers_fn(model_parts, job_config) + optimizers = original_build_optimizers_fn( + model_parts, job_config, parallel_dims, ft_manager + ) optimizers.register_step_post_hook( partial(my_hook, model_parts=model_parts) ) return optimizers spec.build_optimizers_fn = my_build_optimizer_fn + return spec apply_to_train_specs(register_optimizer_hook_to_spec) - model = new_spec.cls(BaseModelArgs()) + model = new_spec.model_cls(BaseModelArgs()) model_parts = [model] - optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig()) + optimizers = new_spec.build_optimizers_fn(model_parts, None, None, None) assert optimizers.optimizers[0].__class__.__name__ == "Adam" batch = torch.randn(8, 8) model(batch).sum().backward() diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index ff055cbe7..1bc07f2f2 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -26,8 +26,8 @@ ) from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType from torch.distributed.checkpoint.stateful import Stateful -from torch.utils.data import DataLoader +from torchtitan.components.dataloader import BaseDataLoader from torchtitan.components.ft import FTManager from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer @@ -180,17 +180,19 @@ class CheckpointManager: def __init__( self, - dataloader: DataLoader, + dataloader: BaseDataLoader | None, model_parts: list[nn.Module], optimizers: OptimizersContainer, lr_schedulers: LRSchedulersContainer, states: dict[str, Any], job_config: JobConfig, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.ft_manager = ft_manager.manager if ft_manager.enabled else None + self.ft_manager = ( + ft_manager.manager if ft_manager and ft_manager.enabled else None + ) if self.ft_manager: optimizers.init_cache_state_dict() diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index cd3604f29..d2ff514cf 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -15,7 +15,6 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.device_mesh import DeviceMesh from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft @@ -244,7 +243,6 @@ def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. 
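
A minimal, self-contained sketch of the two patterns the estimation and trainer changes above rely on: initializing a fake process group before constructing `ParallelDims`, and building the world mesh lazily through a cached property. `TinyDims`, the `"cpu"` mesh, and the 4x2 shape are illustrative stand-ins, not torchtitan APIs (the estimation script itself uses the `"cuda"` device and the real `ParallelDims`):

```python
import os
from dataclasses import dataclass, field

import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore


@dataclass
class TinyDims:
    """Illustrative stand-in for ParallelDims: cheap to construct, mesh built lazily."""

    dp_shard: int
    tp: int
    world_size: int
    _world_mesh: DeviceMesh | None = field(default=None, repr=False)

    def __post_init__(self) -> None:
        assert self.dp_shard * self.tp == self.world_size

    @property
    def world_mesh(self) -> DeviceMesh:
        # Late init: the dataclass stays lightweight; the mesh is only
        # materialized on first access and then cached.
        if self._world_mesh is None:
            self._world_mesh = init_device_mesh(
                "cpu", (self.dp_shard, self.tp), mesh_dim_names=("dp_shard", "tp")
            )
        return self._world_mesh


if __name__ == "__main__":
    world_size = 8
    # A fake process group lets mesh construction run in a single process.
    store = FakeStore()
    torch.distributed.init_process_group(
        "fake",
        rank=int(os.environ.get("LOCAL_RANK", "0")),
        world_size=world_size,
        store=store,
    )
    dims = TinyDims(dp_shard=4, tp=2, world_size=world_size)
    print(dims.world_mesh)  # built on demand here, reused afterwards
```

The lazy property is why the estimation script touches `parallel_dims.world_mesh` once before entering `FakeTensorMode`: mesh construction has to happen outside that mode, while the dataclass itself stays cheap to build.
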
diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 45ecf34f9..6ca11d671 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -7,17 +7,15 @@ import json -import logging import os from abc import ABC, abstractmethod from typing import Any, Optional, Union from tokenizers import AddedToken, Tokenizer from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger from typing_extensions import override -logger = logging.getLogger(__name__) - class BaseTokenizer(ABC): # base tokenizer interface, for typing purpose mainly diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 904c65ca5..7f678514b 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -50,14 +50,12 @@ def __init__( dp_rank: int, tokenizer: BaseTokenizer, parallel_dims: ParallelDims, - world_mesh: torch.distributed.DeviceMesh, loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], ): self.job_config = job_config self.parallel_dims = parallel_dims - self.world_mesh = world_mesh self.loss_fn = loss_fn self.validation_dataloader = build_hf_validation_dataloader( job_config=job_config, @@ -78,6 +76,8 @@ def validate( model = model_parts[0] model.eval() + parallel_dims = self.parallel_dims + accumulated_losses = [] device_type = utils.device_type num_steps = 0 @@ -96,13 +96,13 @@ def validate( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=self.world_mesh["cp"], + cp_mesh=parallel_dims.world_mesh["cp"], cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if self.parallel_dims.cp_enabled + if parallel_dims.cp_enabled else None ) @@ -119,8 +119,10 @@ def validate( # Compute average loss loss = torch.sum(torch.stack(accumulated_losses)) loss /= num_steps - if self.parallel_dims.dp_cp_enabled: - global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"]) + if parallel_dims.dp_cp_enabled: + global_avg_loss = dist_utils.dist_mean( + loss, parallel_dims.world_mesh["dp_cp"] + ) else: global_avg_loss = loss @@ -144,7 +146,6 @@ def build_validator( dp_rank: int, tokenizer: BaseTokenizer, parallel_dims: ParallelDims, - world_mesh: torch.distributed.DeviceMesh, loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], @@ -156,7 +157,6 @@ def build_validator( dp_rank=dp_rank, tokenizer=tokenizer, parallel_dims=parallel_dims, - world_mesh=world_mesh, loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 08986b220..01e14cc0b 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -10,6 +10,7 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torchtitan.tools.logging import logger +from torchtitan.tools.utils import device_type __all__ = ["ParallelDims"] @@ -24,7 +25,8 @@ class ParallelDims: pp: int ep: int world_size: int - enable_loss_parallel: bool + + _world_mesh: DeviceMesh = None def __post_init__(self): self._validate() @@ -55,16 +57,16 @@ def _validate(self): # EP would borrow all cp and some dp_shard degree 
assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - def build_mesh(self, device_type: str) -> DeviceMesh: + def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel # is not very clean, due to the limited support from DeviceMesh # for creating two staggered meshes. Will improve. if self.ep > 1: - return self._build_mesh_with_ep(device_type) + return self._build_mesh_with_ep() else: - return self._build_mesh_without_ep(device_type) + return self._build_mesh_without_ep() - def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: + def _build_mesh_with_ep(self) -> DeviceMesh: # With ep, dp_shard and ep are derived submeshes: # dp_shard = dp_shard_mod_ep * dp_shard_in_ep # ep = dp_shard_in_ep * cp @@ -128,7 +130,7 @@ def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: return mesh - def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: + def _build_mesh_without_ep(self) -> DeviceMesh: dims = [] names = [] for d, name in zip( @@ -173,6 +175,14 @@ def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: return mesh + @property + def world_mesh(self) -> str: + # doing late init so ParallelDims can still be used as a lightweight + # dataclass without having to initialize the world mesh + if self._world_mesh is None: + self._world_mesh = self.build_mesh() + return self._world_mesh + @property def dp_enabled(self): return self.dp_replicate > 1 or self.dp_shard > 1 @@ -206,18 +216,24 @@ def pp_enabled(self): return self.pp > 1 @property - def loss_parallel_enabled(self): - return self.tp > 1 and self.enable_loss_parallel + def ep_enabled(self): + return self.ep > 1 @cached_property def non_data_parallel_size(self): return self.cp * self.tp * self.pp - @property - def ep_enabled(self): - return self.ep > 1 + @cached_property + def seq_len_divisor(self): + # Sequence Parallel requires that seq_len be divisible by TP degree. + # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001 - @property + # Context Parallel requires that seq_len be divisible by 2 * CP degree, + # when load balancing is enabled (by default). + # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246 + return self.tp * (self.cp * 2) + + @cached_property def dense_params_mesh_ndim(self): - # Note: EP params mesh ndim is 1 more due to the 'ep' mesh + # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 3f824d5fe..58c5df0ca 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -307,7 +307,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, - parallel_dims: ParallelDims | None = None, + ep_dense_params_mesh_ndim: int | None = None, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -329,14 +329,15 @@ def clip_grad_norm_( If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: ``None`` - pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. - parallel_dims: ParallelDims object which contains Expert Parallel related info. + pp_mesh: Pipeline Parallel device mesh. If not None, will reduce gradient norm across PP stages. 
+ ep_dense_params_mesh_ndim: Mesh ndim of the dense params when EP is used. If EP is not used, + set it to ``None``. Returns: Total norm of the parameter gradients (viewed as a single vector). """ - if parallel_dims and parallel_dims.ep_enabled: + if ep_dense_params_mesh_ndim is not None: return _clip_grad_norm_with_ep( parameters, max_norm, @@ -344,7 +345,7 @@ def clip_grad_norm_( error_if_nonfinite, foreach, pp_mesh, - parallel_dims, + ep_dense_params_mesh_ndim, ) if isinstance(parameters, torch.Tensor): @@ -388,10 +389,8 @@ def _clip_grad_norm_with_ep( error_if_nonfinite: bool, foreach: bool | None, pp_mesh: DeviceMesh | None, - parallel_dims: ParallelDims, + dense_params_mesh_ndim: int, ) -> torch.Tensor: - assert parallel_dims.ep_enabled - ep_params = [] non_ep_params = [] ep_grads = [] @@ -401,7 +400,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - if p.device_mesh.ndim == parallel_dims.dense_params_mesh_ndim: + if p.device_mesh.ndim == dense_params_mesh_ndim: non_ep_params.append(p) non_ep_grads.append(p.grad) else: diff --git a/torchtitan/experiments/deepseek_v3/__init__.py b/torchtitan/experiments/deepseek_v3/__init__.py index eb515bcfc..f93d0d80e 100644 --- a/torchtitan/experiments/deepseek_v3/__init__.py +++ b/torchtitan/experiments/deepseek_v3/__init__.py @@ -42,8 +42,8 @@ register_train_spec( TrainSpec( name="deepseek3", - cls=DeepseekForCausalLM, - config=deepseek_configs, + model_cls=DeepseekForCausalLM, + model_args=deepseek_configs, parallelize_fn=parallelize_deepseek, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/deepseek_v3/train_ds_real.py b/torchtitan/experiments/deepseek_v3/train_ds_real.py index 398360a6e..be4a92da5 100644 --- a/torchtitan/experiments/deepseek_v3/train_ds_real.py +++ b/torchtitan/experiments/deepseek_v3/train_ds_real.py @@ -155,8 +155,8 @@ def run_full_model( pp=pp_size, cp=1, tp=1, + ep=1, world_size=world_mesh.size(), - enable_loss_parallel=False, ) metrics_processor = build_metrics_processor( @@ -180,7 +180,7 @@ def run_full_model( loss_fn = cross_entropy_loss # torch.nn.functional.cross_entropy ft_manager = ft.init_ft_manager(config) - optimizer = build_optimizers([model], config, ft_manager) + optimizer = build_optimizers([model], config, proxy_parallel_dims, ft_manager) lr_scheduler = build_lr_schedulers(optimizer, config) diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index 5fe8ba3ee..12613a793 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -108,8 +108,8 @@ register_train_spec( TrainSpec( name="flux", - cls=FluxModel, - config=flux_configs, + model_cls=FluxModel, + model_args=flux_configs, parallelize_fn=parallelize_flux, pipelining_fn=None, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/flux/infra/parallelize.py b/torchtitan/experiments/flux/infra/parallelize.py index 460c7f588..69fef68c5 100644 --- a/torchtitan/experiments/flux/infra/parallelize.py +++ b/torchtitan/experiments/flux/infra/parallelize.py @@ -21,7 +21,6 @@ def parallelize_flux( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -36,7 +35,7 @@ def parallelize_flux( apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], 
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -117,7 +116,6 @@ def apply_ac(model: nn.Module, ac_config): def parallelize_encoders( t5_model: nn.Module, clip_model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -132,7 +130,7 @@ def parallelize_encoders( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) fsdp_config = { - "mesh": world_mesh[tuple(dp_mesh_dim_names)], + "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 269abe1c5..c328d12b7 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -36,7 +36,7 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader dist_utils.set_determinism( - self.world_mesh, + self.parallel_dims.world_mesh, self.device, job_config.training.seed, job_config.training.deterministic, @@ -54,11 +54,11 @@ def __init__(self, job_config: JobConfig): ) # load components - model_config = self.train_spec.config[job_config.model.flavor] + model_args = self.train_spec.model_args[job_config.model.flavor] self.autoencoder = load_ae( job_config.encoder.autoencoder_path, - model_config.autoencoder_params, + model_args.autoencoder_params, device=self.device, dtype=self._dtype, random_init=job_config.training.test_mode, @@ -77,7 +77,6 @@ def __init__(self, job_config: JobConfig): self.t5_encoder, self.clip_encoder = parallelize_encoders( t5_model=self.t5_encoder, clip_model=self.clip_encoder, - world_mesh=self.world_mesh, parallel_dims=self.parallel_dims, job_config=job_config, ) diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 9f7affc09..798555ae4 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -94,8 +94,8 @@ register_train_spec( TrainSpec( name="llama4", - cls=Transformer, - config=llama4_configs, + model_cls=Transformer, + model_args=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_llama4_optimizers, diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index d681cd6a1..1b6201128 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -38,7 +38,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -49,6 +48,16 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
+ """ if parallel_dims.tp_enabled: if ( @@ -71,7 +80,7 @@ def parallelize_llama( apply_non_moe_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index d4829de88..11870f5fe 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh from torchtitan.components.ft import FTManager from torchtitan.components.optimizer import build_optimizers, OptimizersContainer @@ -17,10 +16,11 @@ # for MoE auxiliary-loss-free load balancing def _update_expert_bias( model_parts: list[nn.Module], - world_mesh: dict[str, DeviceMesh], parallel_dims: ParallelDims, ): - dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + dp_cp_mesh = ( + parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. for model_part in model_parts: @@ -48,20 +48,18 @@ def build_llama4_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: optimizers = build_optimizers( model_parts=model_parts, job_config=job_config, parallel_dims=parallel_dims, - world_mesh=world_mesh, ft_manager=ft_manager, ) optimizers.register_step_pre_hook( lambda *args, **kwargs: _update_expert_bias( - model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + model_parts, parallel_dims=parallel_dims ) ) diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py index f3ba2a2d4..bbb37d5c5 100644 --- a/torchtitan/experiments/multimodal/__init__.py +++ b/torchtitan/experiments/multimodal/__init__.py @@ -24,8 +24,8 @@ register_train_spec( TrainSpec( name="llama4_multimodal", - cls=MultimodalDecoder, - config=llama4_mm_configs, + model_cls=MultimodalDecoder, + model_args=llama4_mm_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/simple_fsdp/__init__.py b/torchtitan/experiments/simple_fsdp/__init__.py index 80a2b3c3a..2b578dd4b 100644 --- a/torchtitan/experiments/simple_fsdp/__init__.py +++ b/torchtitan/experiments/simple_fsdp/__init__.py @@ -20,8 +20,8 @@ register_train_spec( TrainSpec( name="llama3_simple_fsdp", - cls=SimpleFSDPTransformer, - config=llama3_configs, + model_cls=SimpleFSDPTransformer, + model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index c386fd3d3..7a94adea3 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -7,8 +7,6 @@ import torch import torch.nn as nn -from torch.distributed import DeviceMesh - from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import apply_ac, 
apply_tp @@ -19,7 +17,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -30,6 +27,16 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel @@ -48,11 +55,11 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - tp_mesh = world_mesh["tp"] + tp_mesh = parallel_dims.world_mesh["tp"] apply_tp( model, - world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) @@ -84,7 +91,7 @@ def parallelize_llama( model = data_parallel( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], mode=dp_mode, ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py index 3c15ce573..428182655 100644 --- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -38,10 +38,10 @@ def init_test(self): cp=1, tp=1, pp=1, + ep=1, world_size=self.world_size, - enable_loss_parallel=True, ) - self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda") + self.device_mesh = self.parallel_dims.world_mesh def get_input(self): inputs = torch.randn(8, 8).cuda() diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index e86917bbc..141b740ce 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -113,8 +113,8 @@ register_train_spec( TrainSpec( name="deepseek_v3", - cls=DeepSeekV3Model, - config=deepseekv3_configs, + model_cls=DeepSeekV3Model, + model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, pipelining_fn=None, build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 44e0bc6bb..1ba45f86d 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -26,10 +26,20 @@ # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. 
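
The divisibility asserts added to the llama3, llama4, simple_fsdp, and deepseek_v3 parallelize functions all delegate to the new `ParallelDims.seq_len_divisor` property. A small worked illustration of that invariant, assuming the same `tp * (cp * 2)` formula (`check_seq_len` is an illustrative helper, not a torchtitan API):

```python
# Sequence Parallel needs seq_len % tp == 0, and Context Parallel with load
# balancing needs seq_len % (2 * cp) == 0, so the combined divisor is tp * 2 * cp.
def check_seq_len(seq_len: int, tp: int, cp: int) -> None:
    seq_len_divisor = tp * (cp * 2)
    assert seq_len % seq_len_divisor == 0, (
        f"Sequence length {seq_len} must be divisible by the product of TP degree "
        f"({tp}) and 2 * CP degree ({cp})."
    )


check_seq_len(8192, tp=2, cp=4)    # ok: 8192 % 16 == 0
# check_seq_len(8200, tp=2, cp=4)  # would fail: 8200 % 16 == 8
```
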
+ assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: # TODO(jianiw): This branch needs to be tested and enabled @@ -54,7 +64,7 @@ def parallelize_deepseekv3( apply_non_moe_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, ) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 2e9a11d47..26895274c 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -73,8 +73,8 @@ register_train_spec( TrainSpec( name="llama3", - cls=Transformer, - config=llama3_configs, + model_cls=Transformer, + model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index df395adcb..d67e28372 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -34,7 +34,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -45,16 +44,15 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # TODO: TP currently cannot handle uneven seq_len because we set `use_local_output=True` - # (to use plain Tensors), which was because of the bug in computation of complex - # numbers with DTensors when setting `use_local_output=False`. - # See https://github.com/pytorch/pytorch/issues/130646 and - # https://github.com/pytorch/torchtitan/issues/1306 for details. + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. assert ( - job_config.training.seq_len % (parallel_dims.tp * parallel_dims.cp) == 0 + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree - ({parallel_dims.tp}) and CP degree ({parallel_dims.cp}). + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" if parallel_dims.tp_enabled: @@ -78,7 +76,7 @@ def parallelize_llama( apply_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index 7ad73a229..dfb424b5b 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -8,6 +8,7 @@ import copy +import torch import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage @@ -25,7 +26,7 @@ generate_split_points, stage_ids_this_rank, ) -from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction +from torchtitan.protocols.train_spec import ParallelizeFunction from torchtitan.tools.logging import logger from ..model.args import TransformerModelArgs @@ -33,15 +34,14 @@ def pipeline_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, - device: DeviceType, + device: torch.device, model_config: TransformerModelArgs, parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = world_mesh["pp"] + pp_mesh = parallel_dims.world_mesh["pp"] stages, model_parts = pipeline_llama_manual_split( model, pp_mesh, parallel_dims, job_config, device, model_config @@ -52,7 +52,7 @@ def pipeline_llama( # optimizer, and checkpointing for i, m in enumerate(model_parts): # apply SPMD-style PT-D techniques - m = parallelize_fn(m, world_mesh, parallel_dims, job_config) + m = parallelize_fn(m, parallel_dims, job_config) model_parts[i] = m # NOTE: this is to update the model in the stage # in case the model is modified e.g. by torch.compile @@ -77,7 +77,7 @@ def pipeline_llama_manual_split( pp_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, - device: DeviceType, + device: torch.device, model_config: TransformerModelArgs, ) -> tuple[list[PipelineStage], list[nn.Module]]: """ diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index e7caa89f0..3ee870771 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -7,13 +7,12 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. from abc import abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable from dataclasses import dataclass from typing import Protocol, TypeAlias import torch import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader @@ -27,8 +26,6 @@ from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims -DeviceType = int | str | torch.device - @dataclass class BaseModelArgs: @@ -65,6 +62,11 @@ def __init__(self, model_args: BaseModelArgs) -> None: @abstractmethod def init_weights(self, buffer_device: torch.device | None = None) -> None: + """Initialize model weights. + + Args: + buffer_device: Optional device to place buffers on during initialization. 
+ """ pass @@ -76,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, ParallelDims, DeviceMesh, FTManager], + [list[nn.Module], JobConfig, ParallelDims, FTManager], OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ @@ -89,8 +91,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: @dataclass class TrainSpec: name: str - cls: type[nn.Module] - config: Mapping[str, BaseModelArgs] + model_cls: type[ModelProtocol] + model_args: dict[str, BaseModelArgs] parallelize_fn: ParallelizeFunction pipelining_fn: PipeliningFunction | None build_optimizers_fn: OptimizersBuilder diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index aaa0da8f8..4f10a088a 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -23,10 +23,8 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -def get_device_info(): - device_type = _get_available_device_type() - if device_type is None: - device_type = "cuda" # default device_type: cuda +def get_device_info() -> tuple[str, torch.device]: + device_type = _get_available_device_type() or "cuda" device_module = _get_device_module(device_type) # default device_module:torch.cuda return device_type, device_module diff --git a/torchtitan/train.py b/torchtitan/train.py index 3dc8a61b2..ea7a8e2ef 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -34,32 +34,33 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): + # core configs job_config: JobConfig - gc_handler: utils.GarbageCollection - parallel_dims: ParallelDims train_spec: train_spec_module.TrainSpec - world_mesh: torch.distributed.DeviceMesh - gradient_accumulation_steps: int + # swappable training components in TrainSpec dataloader: train_spec_module.BaseDataLoader - metrics_processor: train_spec_module.MetricsProcessor - checkpointer: CheckpointManager - train_context: Generator[None, None, None] - model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction optimizers: train_spec_module.OptimizersContainer lr_schedulers: train_spec_module.LRSchedulersContainer + validator: train_spec_module.BaseValidator + metrics_processor: train_spec_module.MetricsProcessor - validator: train_spec_module.BaseValidator | None + # non-swappable training components + checkpointer: CheckpointManager + ft_manager: ft.FTManager + # runtime utilities + device: torch.device + gc_handler: utils.GarbageCollection + train_context: Generator[None, None, None] + gradient_accumulation_steps: int pp_has_first_stage: bool pp_has_last_stage: bool - device: torch.device - - # states + # additional training states step: int # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -82,7 +83,8 @@ def __init__(self, job_config: JobConfig): # Device has to be set before creating TorchFT manager. 
device_module.set_device(self.device) - # init distributed + # init distributed and build meshes + dist_utils.init_distributed(job_config) world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism self.parallel_dims = parallel_dims = ParallelDims( @@ -93,12 +95,9 @@ def __init__(self, job_config: JobConfig): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) - dist_utils.init_distributed(job_config) - # build meshes - self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) + world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() @@ -141,8 +140,7 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls - model_args = self.train_spec.config[job_config.model.flavor] + model_args = self.train_spec.model_args[job_config.model.flavor] # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) @@ -150,7 +148,7 @@ def __init__(self, job_config: JobConfig): f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): - model = model_cls(model_args) + model = self.train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) @@ -231,7 +229,6 @@ def __init__(self, job_config: JobConfig): self.pp_has_last_stage, ) = self.train_spec.pipelining_fn( model, - world_mesh, parallel_dims, job_config, self.device, @@ -253,9 +250,7 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model = self.train_spec.parallelize_fn( - model, world_mesh, parallel_dims, job_config - ) + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): @@ -283,7 +278,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager + self.model_parts, job_config, parallel_dims, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -312,8 +307,11 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) self.train_context = dist_utils.get_train_context( - parallel_dims.loss_parallel_enabled, + loss_parallel_enabled, parallelism_config.enable_compiled_autograd, ) self.maybe_enable_amp = dist_utils.maybe_enable_amp( @@ -335,7 +333,6 @@ def __init__(self, job_config: JobConfig): dp_rank=dp_rank, tokenizer=tokenizer, parallel_dims=parallel_dims, - world_mesh=world_mesh, loss_fn=self.train_spec.build_loss_fn(job_config), validation_context=self.train_context, maybe_enable_amp=self.maybe_enable_amp, @@ -391,7 +388,7 @@ def forward_backward_step( inputs = input_dict["input"] optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=self.world_mesh["cp"], + cp_mesh=parallel_dims.world_mesh["cp"], 
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -457,8 +454,14 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, - parallel_dims=parallel_dims, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_dense_params_mesh_ndim=( + parallel_dims.dense_params_mesh_ndim + if parallel_dims.ep_enabled + else None + ), ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() @@ -480,8 +483,8 @@ def train_step( ) ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -546,14 +549,13 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.world_mesh, + world_mesh=self.parallel_dims.world_mesh, ) if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - self.metrics_processor.close() logger.info("Training completed") def state_dict(self) -> dict[str, Any]: @@ -565,6 +567,8 @@ def load_state_dict(self, state_dict: dict[str, Any]): def close(self) -> None: if self.checkpointer: self.checkpointer.close() + if self.metrics_processor: + self.metrics_processor.close() if __name__ == "__main__":
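
For downstream code that registers its own `TrainSpec`, the main breaking change in this diff is the optimizer-builder signature: `OptimizersBuilder` now takes `(model_parts, job_config, parallel_dims, ft_manager)` and no longer receives a `DeviceMesh`. A hedged sketch of adapting a custom builder, mirroring the pattern in `tests/unit_tests/test_train_spec.py` (`my_build_optimizers` and `log_step` are illustrative names, not torchtitan APIs):

```python
from functools import partial

import torch.nn as nn

from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims


def my_build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    parallel_dims: ParallelDims,  # replaces the old (parallel_dims, world_mesh) pair
    ft_manager: FTManager,
) -> OptimizersContainer:
    optimizers = build_optimizers(
        model_parts=model_parts,
        job_config=job_config,
        parallel_dims=parallel_dims,
        ft_manager=ft_manager,
    )

    def log_step(*args, model_parts=None, **kwargs):
        # e.g. inspect parameters or step counters after each optimizer.step()
        pass

    # Same hook mechanism the unit test exercises.
    optimizers.register_step_post_hook(partial(log_step, model_parts=model_parts))
    return optimizers
```

The same substitution shows up at the trainer level: `dist_utils.clip_grad_norm_` now takes `ep_dense_params_mesh_ndim` (computed as `parallel_dims.dense_params_mesh_ndim` when `parallel_dims.ep_enabled`, else `None`) instead of the whole `ParallelDims` object, and any remaining mesh access goes through `parallel_dims.world_mesh`.
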