From 18dd9dd8f6e560be46e8594b0623abcb12030785 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sun, 2 Feb 2025 00:30:13 -0800 Subject: [PATCH 01/16] debugging attention shape --- recipes/configs/llama3_3/70B_full.yaml | 4 +++ recipes/full_finetune_distributed.py | 46 +++++++++++++++++++------- torchtune/training/_distributed.py | 29 +++++++++++++++- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml index 4880a89edc..a1c2b3b1b2 100644 --- a/recipes/configs/llama3_3/70B_full.yaml +++ b/recipes/configs/llama3_3/70B_full.yaml @@ -18,6 +18,10 @@ output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference. +tensor_parallel_dim: 2 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 34ad48e938..b3212bfd00 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -16,6 +16,7 @@ from torch import nn from torch.distributed import destroy_process_group, init_process_group +from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler @@ -25,6 +26,7 @@ from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training._distributed import build_device_mesh from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.checkpointing._checkpoint_client import ( CheckpointClient, @@ -137,8 +139,11 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - _, rank = utils.get_world_size_and_rank() - self._is_rank_zero = rank == 0 + # Distributed variables + self.world_size, self.rank = utils.get_world_size_and_rank() + self._is_rank_zero = self.rank == 0 + self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None)) + self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", None) # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint @@ -521,7 +526,27 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - # For FSDP sharding + device_mesh = build_device_mesh( + self._device.type, + data_parallel_dim=self.world_size // self.tensor_parallel_dim, + tensor_parallel_dim=self.tensor_parallel_dim, + ) + + # Apply tensor parallelism to the model + if self.tensor_parallel_dim is not None and self.tensor_parallel_dim > 1: + if self.parallelize_plan is None: + raise ValueError("Parallelism plan need to be provided when tensor parallel is enabled.") + tp_mesh = device_mesh["tp"] + print(f"{tp_mesh.size()=}") + # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel + # training.prepare_mha_for_tp(model, tp_mesh) + parallelize_module( + model, + tp_mesh, + parallelize_plan=self.parallelize_plan, + ) + + # Apply Fully Sharded Data Parallelism to the model fsdp_shard_conditions = [ partial( training.get_shard_conditions, @@ -533,6 +558,7 @@ def _setup_model( shard_conditions=fsdp_shard_conditions, cpu_offload=fsdp_cpu_offload, reshard_after_forward=reshard_after_forward, + dp_mesh=device_mesh["dp"], ) with training.set_default_dtype(self._dtype), 
self._device: @@ -638,8 +664,6 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = utils.get_world_size_and_rank() - if isinstance(cfg_dataset, ListConfig): datasets = [ config.instantiate(single_cfg_dataset, self._tokenizer) @@ -657,7 +681,7 @@ def _setup_data( collate_fn = _get_component_from_path(collate_fn) sampler = DistributedSampler( - ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0 ) dataloader = DataLoader( dataset=ds, @@ -687,8 +711,6 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = utils.get_world_size_and_rank() - # zero out the gradients before starting training if not self._optimizer_in_bwd: self._optimizer.zero_grad() @@ -708,7 +730,7 @@ def train(self) -> None: # in case shuffle is True self._sampler.set_epoch(curr_epoch) - pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) for idx, batch in enumerate(self._dataloader): if ( self.max_steps_per_epoch is not None @@ -769,7 +791,7 @@ def train(self) -> None: torch.distributed.all_reduce(running_loss) # We multiply by world_size to undo FSDP2 gradient normalization. - current_loss = current_loss * (world_size / num_tokens) + current_loss = current_loss * (self.world_size / num_tokens) current_loss.backward() @@ -782,7 +804,7 @@ def train(self) -> None: torch.distributed.all_reduce(running_loss) # Manually scale the gradients from unnormalized loss by total # of tokens # We multiply by world_size to undo FSDP2 gradient normalization. 
- training.scale_grads(self._model, world_size / num_tokens) + training.scale_grads(self._model, self.world_size / num_tokens) if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), @@ -820,7 +842,7 @@ def train(self) -> None: ), ), "tokens_per_second_per_gpu": num_tokens - / (time_per_step * world_size), + / (time_per_step * self.world_size), } if self._log_peak_memory_stats: log_dict.update( diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 0e33bfd118..7322111575 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -25,7 +25,7 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer @@ -502,6 +502,28 @@ def get_shard_conditions( return False +def build_device_mesh( + device_type: str, + *, + data_parallel_dim: Optional[int] = None, + tensor_parallel_dim: Optional[int] = None, +) -> DeviceMesh: + if not any([data_parallel_dim, tensor_parallel_dim]): + raise ValueError( + "At least one of data_parallel_dim or tensor_parallel_dim must be specified" + ) + + valid_dim = [] + valid_dim_names = [] + if data_parallel_dim is not None and data_parallel_dim >= 0: + valid_dim.append(data_parallel_dim) + valid_dim_names.append("dp") + if tensor_parallel_dim is not None and tensor_parallel_dim >= 0: + valid_dim.append(tensor_parallel_dim) + valid_dim_names.append("tp") + + return init_device_mesh(device_type, valid_dim, mesh_dim_names=valid_dim_names) + def shard_model( model: TransformerDecoder, @@ -509,6 +531,7 @@ def shard_model( *, cpu_offload: bool, reshard_after_forward: bool = True, + dp_mesh: Optional[DeviceMesh] = None, ) -> None: """ Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API. @@ -527,6 +550,8 @@ def shard_model( reshard_after_forward (bool): Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy. + dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under mutliple parallelism. + Default to None. Raises: ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. 
@@ -534,6 +559,8 @@ def shard_model( fsdp_kwargs = {"reshard_after_forward": reshard_after_forward} if cpu_offload: fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + if dp_mesh: + fsdp_kwargs["mesh"] = dp_mesh # Shard the model with FSDP, iterating in reverse to start with # lowest-level modules first From a6251757e45907b6b144a97cc7ce217f7f31256c Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sun, 2 Feb 2025 11:55:04 -0800 Subject: [PATCH 02/16] working --- recipes/configs/llama3/8B_full.yaml | 6 ++- recipes/configs/llama3_3/70B_full.yaml | 4 -- recipes/full_finetune_distributed.py | 54 ++++++++++++++------------ torchtune/training/_distributed.py | 37 +++++------------- 4 files changed, 44 insertions(+), 57 deletions(-) diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 4c23331171..870d6a4907 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -20,6 +20,10 @@ output_dir: /tmp/torchtune/llama3_8B/full # /tmp may be deleted by your system. Change it to your preference. +tensor_parallel_dim: 4 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer @@ -55,7 +59,7 @@ epochs: 1 optimizer: _component_: torch.optim.AdamW lr: 2e-5 - fused: True + fused: False loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml index a1c2b3b1b2..4880a89edc 100644 --- a/recipes/configs/llama3_3/70B_full.yaml +++ b/recipes/configs/llama3_3/70B_full.yaml @@ -18,10 +18,6 @@ output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference. -tensor_parallel_dim: 2 -parallelize_plan: - _component_: torchtune.models.llama3.base_llama_tp_plan - # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b3212bfd00..ea7d6b1b58 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -12,6 +12,7 @@ from warnings import warn import torch +import torch.distributed as dist from omegaconf import DictConfig, ListConfig from torch import nn @@ -26,7 +27,6 @@ from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY -from torchtune.training._distributed import build_device_mesh from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.checkpointing._checkpoint_client import ( CheckpointClient, @@ -143,7 +143,12 @@ def __init__(self, cfg: DictConfig) -> None: self.world_size, self.rank = utils.get_world_size_and_rank() self._is_rank_zero = self.rank == 0 self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None)) - self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", None) + self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1) + if self.world_size % self.tensor_parallel_dim != 0: + raise ValueError( + f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}" + ) + self.data_parallel_dim = self.world_size // self.tensor_parallel_dim # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint @@ -497,7 +502,7 @@ def _setup_model( utils.log_rank_zero( log, - "FSDP is enabled. 
Instantiating model and loading checkpoint on Rank 0 ...", + "FSDP and TP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", ) init_start = time.perf_counter() @@ -526,20 +531,19 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - device_mesh = build_device_mesh( + self.device_mesh = dist.init_device_mesh( self._device.type, - data_parallel_dim=self.world_size // self.tensor_parallel_dim, - tensor_parallel_dim=self.tensor_parallel_dim, + mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim), + mesh_dim_names=("dp", "tp"), ) # Apply tensor parallelism to the model - if self.tensor_parallel_dim is not None and self.tensor_parallel_dim > 1: + if self.tensor_parallel_dim > 1: if self.parallelize_plan is None: raise ValueError("Parallelism plan need to be provided when tensor parallel is enabled.") - tp_mesh = device_mesh["tp"] - print(f"{tp_mesh.size()=}") + tp_mesh = self.device_mesh["tp"] # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel - # training.prepare_mha_for_tp(model, tp_mesh) + # model = training.prepare_mha_for_tp(model, tp_mesh) parallelize_module( model, tp_mesh, @@ -547,19 +551,20 @@ def _setup_model( ) # Apply Fully Sharded Data Parallelism to the model - fsdp_shard_conditions = [ - partial( - training.get_shard_conditions, - names_to_match=custom_sharded_layers, + if self.data_parallel_dim > 1: + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + dp_mesh=self.device_mesh["dp"], ) - ] - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - dp_mesh=device_mesh["dp"], - ) with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): @@ -680,8 +685,10 @@ def _setup_data( raise RuntimeError("left_pad_sequence collator is only for inference.") collate_fn = _get_component_from_path(collate_fn) + dp_mesh = self.device_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() sampler = DistributedSampler( - ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0 + ds, num_replicas=dp_degree, rank=dp_rank, shuffle=shuffle, seed=0 ) dataloader = DataLoader( dataset=ds, @@ -748,7 +755,6 @@ def train(self) -> None: and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() - utils.batch_to_device(batch, self._device) # Calculate the number of unmasked tokens in the current batch diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 7322111575..4907a59d9f 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -25,14 +25,14 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune.modules import TransformerDecoder from torchtune.modules.attention import MultiHeadAttention -from torchtune.modules.model_fusion import DeepFusionModel +from torchtune.modules.model_fusion import 
DeepFusionModel, EarlyFusionModel from torchtune.modules.peft import get_adapter_state_dict from torchtune.utils import get_device, get_logger from torchtune.utils._logging import deprecated @@ -502,28 +502,6 @@ def get_shard_conditions( return False -def build_device_mesh( - device_type: str, - *, - data_parallel_dim: Optional[int] = None, - tensor_parallel_dim: Optional[int] = None, -) -> DeviceMesh: - if not any([data_parallel_dim, tensor_parallel_dim]): - raise ValueError( - "At least one of data_parallel_dim or tensor_parallel_dim must be specified" - ) - - valid_dim = [] - valid_dim_names = [] - if data_parallel_dim is not None and data_parallel_dim >= 0: - valid_dim.append(data_parallel_dim) - valid_dim_names.append("dp") - if tensor_parallel_dim is not None and tensor_parallel_dim >= 0: - valid_dim.append(tensor_parallel_dim) - valid_dim_names.append("tp") - - return init_device_mesh(device_type, valid_dim, mesh_dim_names=valid_dim_names) - def shard_model( model: TransformerDecoder, @@ -612,11 +590,11 @@ def prepare_mha_for_tp( >>> # num_kv_heads = 16 (32/2) >>> # embed_dim = 2048 (4096/2) """ - # Consider the case of Deep Fusion models - if isinstance(model, DeepFusionModel): - model = model.decoder + # Handle fusion models by extracting decoder + is_fusion_model = isinstance(model, (DeepFusionModel, EarlyFusionModel)) + decoder = model.decoder if is_fusion_model else model tp_size = tp_mesh.size() - for m in list(model.modules()): + for m in list(decoder.modules()): if isinstance(m, MultiHeadAttention): # Adjust attention module to use the local number of heads if m.num_heads % tp_size != 0: @@ -637,4 +615,7 @@ def prepare_mha_for_tp( m.num_heads = m.num_heads // tp_size m.num_kv_heads = m.num_kv_heads // tp_size m.embed_dim = m.embed_dim // tp_size + + if is_fusion_model: + model.decoder = decoder return model From 0bd8e5651fbefd14b8afaa52648ff64ae54fa9a5 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 3 Feb 2025 21:39:05 -0800 Subject: [PATCH 03/16] distributed training working --- .../llama3/70B_generation_distributed.yaml | 14 +++---- recipes/dev/generate_v2_distributed.py | 3 +- recipes/full_finetune_distributed.py | 42 +++++++++---------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/recipes/configs/llama3/70B_generation_distributed.yaml b/recipes/configs/llama3/70B_generation_distributed.yaml index 78c77ba263..ed7ebe3784 100644 --- a/recipes/configs/llama3/70B_generation_distributed.yaml +++ b/recipes/configs/llama3/70B_generation_distributed.yaml @@ -11,7 +11,7 @@ output_dir: ./ # Model arguments model: - _component_: torchtune.models.llama3.llama3_70b + _component_: torchtune.models.llama3.llama3_8b parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan @@ -19,17 +19,17 @@ parallelize_plan: # Transform arguments tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model prompt_template: null max_seq_len: 8192 # Checkpointer checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct - checkpoint_files: - filename_format: model-{}-of-{}.safetensors - max_filename: "00030" + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] recipe_checkpoint: null output_dir: ${output_dir} model_type: LLAMA3 diff --git 
a/recipes/dev/generate_v2_distributed.py b/recipes/dev/generate_v2_distributed.py index 48a147bd15..d1f68345e0 100644 --- a/recipes/dev/generate_v2_distributed.py +++ b/recipes/dev/generate_v2_distributed.py @@ -105,13 +105,14 @@ def setup(self, cfg: DictConfig) -> None: tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape) # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell - training.prepare_mha_for_tp(model, tp_device_mesh) + model = training.prepare_mha_for_tp(model, tp_device_mesh) parallelize_module( model, tp_device_mesh, parallelize_plan=config.instantiate(cfg.parallelize_plan), ) + with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): # RoPE is not covered in state dict diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index ea7d6b1b58..bb42315576 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -512,25 +512,6 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) - # We currently have two versions of activation checkpointing in this recipe - # for testing and BC purposes. ``enable_activation_checkpointing`` controls - # the older version of AC and this behavior is unchanged - # ac_mode and ac_option together control selective AC. This is only enabled - # when these are set AND ``enable_activation_checkpointing`` is set to False - # We'll clean this up as soon as testing of AC is complete - if (not enable_activation_checkpointing) and (ac_mode is not None): - apply_selective_activation_checkpointing( - model, - ac_mode, - ac_option, - ) - - # original activation checkpointing (full) - flip the condition above - if enable_activation_checkpointing and ac_mode is None: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} - ) - self.device_mesh = dist.init_device_mesh( self._device.type, mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim), @@ -543,13 +524,13 @@ def _setup_model( raise ValueError("Parallelism plan need to be provided when tensor parallel is enabled.") tp_mesh = self.device_mesh["tp"] # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel - # model = training.prepare_mha_for_tp(model, tp_mesh) + model = training.prepare_mha_for_tp(model, tp_mesh) parallelize_module( model, tp_mesh, parallelize_plan=self.parallelize_plan, ) - + # Apply Fully Sharded Data Parallelism to the model if self.data_parallel_dim > 1: fsdp_shard_conditions = [ @@ -566,6 +547,25 @@ def _setup_model( dp_mesh=self.device_mesh["dp"], ) + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. 
This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): # RoPE is not covered in state dict From ff0dee3865fa3a2d8c42feef3b19e27f3097834b Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 3 Feb 2025 22:16:37 -0800 Subject: [PATCH 04/16] added configs --- recipes/configs/llama3/70B_full.yaml | 7 ++++++- recipes/configs/llama3/8B_full.yaml | 4 ---- recipes/configs/llama3_1/70B_full.yaml | 5 +++++ recipes/configs/llama3_3/70B_full.yaml | 5 +++++ recipes/full_finetune_distributed.py | 1 + 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index ee9c914ce8..618e8e0a1b 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -19,6 +19,11 @@ output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference. +# Parallelism +tensor_parallel_dim: 1 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer @@ -54,7 +59,7 @@ epochs: 1 optimizer: _component_: torch.optim.AdamW lr: 2e-5 - fused: True + fused: False loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 870d6a4907..ca996e3b6f 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -20,10 +20,6 @@ output_dir: /tmp/torchtune/llama3_8B/full # /tmp may be deleted by your system. Change it to your preference. -tensor_parallel_dim: 4 -parallelize_plan: - _component_: torchtune.models.llama3.base_llama_tp_plan - # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 0c4c7fce7f..6fb85fe280 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -18,6 +18,11 @@ output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference. +# Parallelism +tensor_parallel_dim: 1 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml index 4880a89edc..223a24eaa2 100644 --- a/recipes/configs/llama3_3/70B_full.yaml +++ b/recipes/configs/llama3_3/70B_full.yaml @@ -18,6 +18,11 @@ output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference. 
+# Parallelism +tensor_parallel_dim: 1 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index bb42315576..3e0e4953e1 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -519,6 +519,7 @@ def _setup_model( ) # Apply tensor parallelism to the model + print(f"{self.tensor_parallel_dim=}") if self.tensor_parallel_dim > 1: if self.parallelize_plan is None: raise ValueError("Parallelism plan need to be provided when tensor parallel is enabled.") From be8ab60c4756220583610dcad87ad57197823165 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 3 Feb 2025 22:17:22 -0800 Subject: [PATCH 05/16] formatting --- recipes/dev/generate_v2_distributed.py | 1 - recipes/full_finetune_distributed.py | 6 ++++-- torchtune/training/_distributed.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/recipes/dev/generate_v2_distributed.py b/recipes/dev/generate_v2_distributed.py index d1f68345e0..8cbf139672 100644 --- a/recipes/dev/generate_v2_distributed.py +++ b/recipes/dev/generate_v2_distributed.py @@ -112,7 +112,6 @@ def setup(self, cfg: DictConfig) -> None: parallelize_plan=config.instantiate(cfg.parallelize_plan), ) - with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): # RoPE is not covered in state dict diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 3e0e4953e1..6cd8a50316 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -522,7 +522,9 @@ def _setup_model( print(f"{self.tensor_parallel_dim=}") if self.tensor_parallel_dim > 1: if self.parallelize_plan is None: - raise ValueError("Parallelism plan need to be provided when tensor parallel is enabled.") + raise ValueError( + "Parallelism plan need to be provided when tensor parallel is enabled." 
+ ) tp_mesh = self.device_mesh["tp"] # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel model = training.prepare_mha_for_tp(model, tp_mesh) @@ -531,7 +533,7 @@ def _setup_model( tp_mesh, parallelize_plan=self.parallelize_plan, ) - + # Apply Fully Sharded Data Parallelism to the model if self.data_parallel_dim > 1: fsdp_shard_conditions = [ diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 4907a59d9f..744def096a 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -615,7 +615,7 @@ def prepare_mha_for_tp( m.num_heads = m.num_heads // tp_size m.num_kv_heads = m.num_kv_heads // tp_size m.embed_dim = m.embed_dim // tp_size - + if is_fusion_model: model.decoder = decoder return model From 0793d0819505822fa0743d1f6d7b376790fd9911 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 3 Feb 2025 22:26:38 -0800 Subject: [PATCH 06/16] formatting --- .../llama3/70B_generation_distributed.yaml | 14 ++++---- recipes/configs/llama3/8B_full.yaml | 2 +- recipes/configs/llama3_1/70B_full.yaml | 4 +-- recipes/configs/llama3_3/70B_full.yaml | 2 +- recipes/full_finetune_distributed.py | 32 +++++++++---------- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/recipes/configs/llama3/70B_generation_distributed.yaml b/recipes/configs/llama3/70B_generation_distributed.yaml index ed7ebe3784..78c77ba263 100644 --- a/recipes/configs/llama3/70B_generation_distributed.yaml +++ b/recipes/configs/llama3/70B_generation_distributed.yaml @@ -11,7 +11,7 @@ output_dir: ./ # Model arguments model: - _component_: torchtune.models.llama3.llama3_8b + _component_: torchtune.models.llama3.llama3_70b parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan @@ -19,17 +19,17 @@ parallelize_plan: # Transform arguments tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model prompt_template: null max_seq_len: 8192 # Checkpointer checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ - checkpoint_files: [ - consolidated.00.pth - ] + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00030" recipe_checkpoint: null output_dir: ${output_dir} model_type: LLAMA3 diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index ca996e3b6f..4c23331171 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -55,7 +55,7 @@ epochs: 1 optimizer: _component_: torch.optim.AdamW lr: 2e-5 - fused: False + fused: True loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 6fb85fe280..7a5ab42528 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -19,7 +19,7 @@ output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference. 
# Parallelism -tensor_parallel_dim: 1 +tensor_parallel_dim: 4 parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan @@ -60,7 +60,7 @@ optimizer: lr: 2e-5 # Note: highly recommended to use fused=True optimizer flag # with CPU offload for faster optimizer step. - fused: True + fused: False loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml index 223a24eaa2..4d9df225ce 100644 --- a/recipes/configs/llama3_3/70B_full.yaml +++ b/recipes/configs/llama3_3/70B_full.yaml @@ -60,7 +60,7 @@ optimizer: lr: 2e-5 # Note: highly recommended to use fused=True optimizer flag # with CPU offload for faster optimizer step. - fused: True + fused: False loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 6cd8a50316..b9b44c109d 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -534,22 +534,6 @@ def _setup_model( parallelize_plan=self.parallelize_plan, ) - # Apply Fully Sharded Data Parallelism to the model - if self.data_parallel_dim > 1: - fsdp_shard_conditions = [ - partial( - training.get_shard_conditions, - names_to_match=custom_sharded_layers, - ) - ] - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - dp_mesh=self.device_mesh["dp"], - ) - # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls # the older version of AC and this behavior is unchanged @@ -569,6 +553,22 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + # Apply Fully Sharded Data Parallelism to the model + if self.data_parallel_dim > 1: + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + dp_mesh=self.device_mesh["dp"], + ) + with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): # RoPE is not covered in state dict From b441391f1512d78a046389a6f718ccffb6cd107a Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 3 Feb 2025 22:37:35 -0800 Subject: [PATCH 07/16] misc --- recipes/configs/llama3_1/70B_full.yaml | 2 +- recipes/full_finetune_distributed.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 7a5ab42528..146c51da79 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -19,7 +19,7 @@ output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference. 
# Parallelism -tensor_parallel_dim: 4 +tensor_parallel_dim: 1 parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b9b44c109d..31fbc7e6a1 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -519,7 +519,6 @@ def _setup_model( ) # Apply tensor parallelism to the model - print(f"{self.tensor_parallel_dim=}") if self.tensor_parallel_dim > 1: if self.parallelize_plan is None: raise ValueError( From 244a77d96058e2a53a3c5d1bee00a87675ae2e74 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 15:42:31 -0800 Subject: [PATCH 08/16] misc --- recipes/full_finetune_distributed.py | 25 ++++++++++++------------- torchtune/training/_distributed.py | 4 +--- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 31fbc7e6a1..8773697589 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -144,6 +144,10 @@ def __init__(self, cfg: DictConfig) -> None: self._is_rank_zero = self.rank == 0 self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None)) self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1) + if self.tensor_parallel_dim > 1 and self.parallelize_plan is None: + raise ValueError( + "Parallelism plan need to be provided when tensor parallel is enabled." + ) if self.world_size % self.tensor_parallel_dim != 0: raise ValueError( f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}" @@ -502,7 +506,7 @@ def _setup_model( utils.log_rank_zero( log, - "FSDP and TP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + "Distributed training(FSDP and TP) is enabled. Instantiating model and loading checkpoint on Rank 0 ...", ) init_start = time.perf_counter() @@ -512,24 +516,21 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) - self.device_mesh = dist.init_device_mesh( + device_mesh = dist.init_device_mesh( self._device.type, mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim), mesh_dim_names=("dp", "tp"), ) + self.dp_size = device_mesh["dp"].size() + self.dp_rank = device_mesh["dp"].get_local_rank() # Apply tensor parallelism to the model if self.tensor_parallel_dim > 1: - if self.parallelize_plan is None: - raise ValueError( - "Parallelism plan need to be provided when tensor parallel is enabled." 
- ) - tp_mesh = self.device_mesh["tp"] # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel - model = training.prepare_mha_for_tp(model, tp_mesh) + model = training.prepare_mha_for_tp(model, device_mesh["tp"]) parallelize_module( model, - tp_mesh, + device_mesh["tp"], parallelize_plan=self.parallelize_plan, ) @@ -565,7 +566,7 @@ def _setup_model( shard_conditions=fsdp_shard_conditions, cpu_offload=fsdp_cpu_offload, reshard_after_forward=reshard_after_forward, - dp_mesh=self.device_mesh["dp"], + dp_mesh=device_mesh["dp"], ) with training.set_default_dtype(self._dtype), self._device: @@ -687,10 +688,8 @@ def _setup_data( raise RuntimeError("left_pad_sequence collator is only for inference.") collate_fn = _get_component_from_path(collate_fn) - dp_mesh = self.device_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() sampler = DistributedSampler( - ds, num_replicas=dp_degree, rank=dp_rank, shuffle=shuffle, seed=0 + ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0 ) dataloader = DataLoader( dataset=ds, diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 744def096a..8a4c83961c 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -534,11 +534,9 @@ def shard_model( Raises: ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. """ - fsdp_kwargs = {"reshard_after_forward": reshard_after_forward} + fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh} if cpu_offload: fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() - if dp_mesh: - fsdp_kwargs["mesh"] = dp_mesh # Shard the model with FSDP, iterating in reverse to start with # lowest-level modules first From da815c5bb75ad39afb4435c4d22ef321f8e867dd Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 17:35:37 -0800 Subject: [PATCH 09/16] misc --- recipes/configs/llama3/70B_full.yaml | 2 +- recipes/configs/llama3/8B_full.yaml | 5 +++++ tests/recipes/test_full_finetune_distributed.py | 13 +++++++++---- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 618e8e0a1b..1fd6151a6b 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -20,7 +20,7 @@ output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference. # Parallelism -tensor_parallel_dim: 1 +tensor_parallel_dim: 2 parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 4c23331171..402fe41ead 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -20,6 +20,11 @@ output_dir: /tmp/torchtune/llama3_8B/full # /tmp may be deleted by your system. Change it to your preference. 
+# Parallelism +tensor_parallel_dim: 1 +parallelize_plan: + _component_: torchtune.models.llama3.base_llama_tp_plan + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 4cdc42d96b..a9f6ee2587 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -65,11 +65,12 @@ def _fetch_expected_loss_values_single_rank(self, model_type): @pytest.mark.integration_test @pytest.mark.parametrize( - "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim", [ - ("llama2/7B_full", "llama2", "hf", 1, 4, False), - ("llama3/8B_full", "llama3", "tune", 1, 4, False), - ("llama3/8B_full", "llama3", "tune", 4, 1, True), + # ("llama2/7B_full", "llama2", "hf", 1, 4, False, 1), + ("llama3/8B_full", "llama3", "tune", 1, 4, False, 1), + ("llama3/8B_full", "llama3", "tune", 4, 1, True, 1), + ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2), ], ) @gpu_test(gpu_count=2) @@ -81,6 +82,7 @@ def test_loss( model_type, ckpt_type, optim_in_bwd, + tensor_parallel_dim, tmpdir, monkeypatch, ): @@ -107,6 +109,7 @@ def test_loss( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ + tensor_parallel_dim={tensor_parallel_dim} \ metric_logger.filename={log_file} \ """.split() model_config = MODEL_TEST_CONFIGS[model_type] @@ -124,6 +127,8 @@ def test_loss( monkeypatch.setattr(sys, "argv", cmd) runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = get_loss_values_from_metric_logger(log_file) + print(f"{loss_values=}") + assert False expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 From 0d692e2593148b044bbd18d4c8ca375f3815459c Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 21:13:03 -0800 Subject: [PATCH 10/16] add tests --- .../recipes/test_full_finetune_distributed.py | 88 ++++++++++++++++--- 1 file changed, 76 insertions(+), 12 deletions(-) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index a9f6ee2587..8c10bdd5d7 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -67,7 +67,7 @@ def _fetch_expected_loss_values_single_rank(self, model_type): @pytest.mark.parametrize( "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim", [ - # ("llama2/7B_full", "llama2", "hf", 1, 4, False, 1), + ("llama2/7B_full", "llama2", "hf", 1, 4, False, 1), ("llama3/8B_full", "llama3", "tune", 1, 4, False, 1), ("llama3/8B_full", "llama3", "tune", 4, 1, True, 1), ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2), @@ -109,7 +109,7 @@ def test_loss( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - tensor_parallel_dim={tensor_parallel_dim} \ + tnesor_parallel_dim={tensor_parallel_dim} \ metric_logger.filename={log_file} \ """.split() model_config = MODEL_TEST_CONFIGS[model_type] @@ -127,8 +127,6 @@ def test_loss( monkeypatch.setattr(sys, "argv", cmd) runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = get_loss_values_from_metric_logger(log_file) - 
print(f"{loss_values=}") - assert False expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 @@ -136,15 +134,13 @@ def test_loss( @pytest.mark.integration_test @pytest.mark.parametrize( - "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim", [ - ("llama2/7B_full", "llama2", "hf", 1, 4, False), - ("llama3/8B_full", "llama3", "tune", 1, 4, False), - ("llama3/8B_full", "llama3", "tune", 4, 1, True), + ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2), ], ) - @gpu_test(gpu_count=1) - def test_loss_single_rank( + @gpu_test(gpu_count=4) + def test_loss_2d_parallel( self, micro_batch_size, gradient_accumulation_steps, @@ -152,6 +148,7 @@ def test_loss_single_rank( model_type, ckpt_type, optim_in_bwd, + tensor_parallel_dim, tmpdir, monkeypatch, ): @@ -166,7 +163,7 @@ def test_loss_single_rank( write_hf_ckpt_config(ckpt_dir) cmd = f""" - tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ --config {config} \ batch_size={micro_batch_size} \ gradient_accumulation_steps={gradient_accumulation_steps} \ @@ -178,6 +175,7 @@ def test_loss_single_rank( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ + tnesor_parallel_dim={tensor_parallel_dim} \ metric_logger.filename={log_file} \ """.split() model_config = MODEL_TEST_CONFIGS[model_type] @@ -187,17 +185,83 @@ def test_loss_single_rank( # should be the same. if not optim_in_bwd: cmd.append("clip_grad_norm=100") + # Test that gradient clipping works with CPU offload + cmd.append("fsdp_cpu_offload=True") else: cmd.append("optimizer_in_bwd=True") monkeypatch.setattr(sys, "argv", cmd) runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = get_loss_values_from_metric_logger(log_file) - expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) + expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) + # @pytest.mark.integration_test + # @pytest.mark.parametrize( + # "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + # [ + # ("llama2/7B_full", "llama2", "hf", 1, 4, False), + # ("llama3/8B_full", "llama3", "tune", 1, 4, False), + # ("llama3/8B_full", "llama3", "tune", 4, 1, True), + # ], + # ) + # @gpu_test(gpu_count=1) + # def test_loss_single_rank( + # self, + # micro_batch_size, + # gradient_accumulation_steps, + # config, + # model_type, + # ckpt_type, + # optim_in_bwd, + # tmpdir, + # monkeypatch, + # ): + # ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + # ckpt = model_type + "_" + ckpt_type + # ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + # tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + # ckpt_dir = ckpt_path.parent + # log_file = gen_log_file_name(tmpdir) + + # # Config file needed for model conversion. 
+ # write_hf_ckpt_config(ckpt_dir) + + # cmd = f""" + # tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + # --config {config} \ + # batch_size={micro_batch_size} \ + # gradient_accumulation_steps={gradient_accumulation_steps} \ + # output_dir={tmpdir} \ + # checkpointer._component_={ckpt_component} \ + # checkpointer.checkpoint_dir='{ckpt_dir}' \ + # checkpointer.checkpoint_files=[{ckpt_path}]\ + # checkpointer.output_dir={tmpdir} \ + # checkpointer.model_type={model_type.upper()} \ + # tokenizer.path='{tokenizer_path}' \ + # tokenizer.prompt_template=null \ + # metric_logger.filename={log_file} \ + # """.split() + # model_config = MODEL_TEST_CONFIGS[model_type] + # cmd = cmd + self._get_test_config_overrides() + model_config + # # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing + # # wrong grad_norm, so we only test one of them each time. But loss values + # # should be the same. + # if not optim_in_bwd: + # cmd.append("clip_grad_norm=100") + # else: + # cmd.append("optimizer_in_bwd=True") + + # monkeypatch.setattr(sys, "argv", cmd) + # runpy.run_path(TUNE_PATH, run_name="__main__") + # loss_values = get_loss_values_from_metric_logger(log_file) + # expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) + # torch.testing.assert_close( + # loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + # ) + @pytest.mark.integration_test @pytest.mark.parametrize( "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", From 1a749bf06d4312cf74bb2184a69cd293baada211 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 21:17:02 -0800 Subject: [PATCH 11/16] update test --- .../recipes/test_full_finetune_distributed.py | 126 +++++++++--------- 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 8c10bdd5d7..64344cbd98 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -198,69 +198,69 @@ def test_loss_2d_parallel( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) - # @pytest.mark.integration_test - # @pytest.mark.parametrize( - # "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", - # [ - # ("llama2/7B_full", "llama2", "hf", 1, 4, False), - # ("llama3/8B_full", "llama3", "tune", 1, 4, False), - # ("llama3/8B_full", "llama3", "tune", 4, 1, True), - # ], - # ) - # @gpu_test(gpu_count=1) - # def test_loss_single_rank( - # self, - # micro_batch_size, - # gradient_accumulation_steps, - # config, - # model_type, - # ckpt_type, - # optim_in_bwd, - # tmpdir, - # monkeypatch, - # ): - # ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] - # ckpt = model_type + "_" + ckpt_type - # ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - # tokenizer_path = Path(TOKENIZER_PATHS[model_type]) - # ckpt_dir = ckpt_path.parent - # log_file = gen_log_file_name(tmpdir) - - # # Config file needed for model conversion. 
- # write_hf_ckpt_config(ckpt_dir) - - # cmd = f""" - # tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ - # --config {config} \ - # batch_size={micro_batch_size} \ - # gradient_accumulation_steps={gradient_accumulation_steps} \ - # output_dir={tmpdir} \ - # checkpointer._component_={ckpt_component} \ - # checkpointer.checkpoint_dir='{ckpt_dir}' \ - # checkpointer.checkpoint_files=[{ckpt_path}]\ - # checkpointer.output_dir={tmpdir} \ - # checkpointer.model_type={model_type.upper()} \ - # tokenizer.path='{tokenizer_path}' \ - # tokenizer.prompt_template=null \ - # metric_logger.filename={log_file} \ - # """.split() - # model_config = MODEL_TEST_CONFIGS[model_type] - # cmd = cmd + self._get_test_config_overrides() + model_config - # # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing - # # wrong grad_norm, so we only test one of them each time. But loss values - # # should be the same. - # if not optim_in_bwd: - # cmd.append("clip_grad_norm=100") - # else: - # cmd.append("optimizer_in_bwd=True") - - # monkeypatch.setattr(sys, "argv", cmd) - # runpy.run_path(TUNE_PATH, run_name="__main__") - # loss_values = get_loss_values_from_metric_logger(log_file) - # expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) - # torch.testing.assert_close( - # loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 - # ) + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama2/7B_full", "llama2", "hf", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 4, 1, True), + ], + ) + @gpu_test(gpu_count=1) + def test_loss_single_rank( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + write_hf_ckpt_config(ckpt_dir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + """.split() + model_config = MODEL_TEST_CONFIGS[model_type] + cmd = cmd + self._get_test_config_overrides() + model_config + # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing + # wrong grad_norm, so we only test one of them each time. But loss values + # should be the same. 
+ if not optim_in_bwd: + cmd.append("clip_grad_norm=100") + else: + cmd.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) @pytest.mark.integration_test @pytest.mark.parametrize( From 9387a9a8baa77f6d3c64fc9086d6eba0de1f58a9 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 22:12:37 -0800 Subject: [PATCH 12/16] formatting --- recipes/configs/llama3/8B_full.yaml | 5 ----- .../recipes/test_full_finetune_distributed.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 402fe41ead..4c23331171 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -20,11 +20,6 @@ output_dir: /tmp/torchtune/llama3_8B/full # /tmp may be deleted by your system. Change it to your preference. -# Parallelism -tensor_parallel_dim: 1 -parallelize_plan: - _component_: torchtune.models.llama3.base_llama_tp_plan - # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 64344cbd98..78bc6a9c36 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -65,12 +65,11 @@ def _fetch_expected_loss_values_single_rank(self, model_type): @pytest.mark.integration_test @pytest.mark.parametrize( - "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", [ - ("llama2/7B_full", "llama2", "hf", 1, 4, False, 1), - ("llama3/8B_full", "llama3", "tune", 1, 4, False, 1), - ("llama3/8B_full", "llama3", "tune", 4, 1, True, 1), - ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2), + ("llama2/7B_full", "llama2", "hf", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 4, 1, True), ], ) @gpu_test(gpu_count=2) @@ -82,7 +81,6 @@ def test_loss( model_type, ckpt_type, optim_in_bwd, - tensor_parallel_dim, tmpdir, monkeypatch, ): @@ -109,7 +107,6 @@ def test_loss( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - tnesor_parallel_dim={tensor_parallel_dim} \ metric_logger.filename={log_file} \ """.split() model_config = MODEL_TEST_CONFIGS[model_type] @@ -158,12 +155,13 @@ def test_loss_2d_parallel( tokenizer_path = Path(TOKENIZER_PATHS[model_type]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) + parallelize_plan = "torchtune.models.llama3.base_llama_tp_plan" # Config file needed for model conversion. 
write_hf_ckpt_config(ckpt_dir) cmd = f""" - tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ --config {config} \ batch_size={micro_batch_size} \ gradient_accumulation_steps={gradient_accumulation_steps} \ @@ -175,7 +173,8 @@ def test_loss_2d_parallel( checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - tnesor_parallel_dim={tensor_parallel_dim} \ + tensor_parallel_dim={tensor_parallel_dim} \ + parallelize_plan._component_={parallelize_plan} \ metric_logger.filename={log_file} \ """.split() model_config = MODEL_TEST_CONFIGS[model_type] @@ -194,6 +193,7 @@ def test_loss_2d_parallel( runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = get_loss_values_from_metric_logger(log_file) expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) + print(f"{loss_values=}") torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) From fcc2295878ee87d54b7f638fb6f84c80c2cd5751 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Wed, 5 Feb 2025 23:32:46 -0800 Subject: [PATCH 13/16] misc --- recipes/configs/llama3/70B_full.yaml | 2 +- tests/recipes/test_full_finetune_distributed.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 1fd6151a6b..618e8e0a1b 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -20,7 +20,7 @@ output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference. # Parallelism -tensor_parallel_dim: 2 +tensor_parallel_dim: 1 parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 78bc6a9c36..904e188a4a 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -193,7 +193,6 @@ def test_loss_2d_parallel( runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = get_loss_values_from_metric_logger(log_file) expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) - print(f"{loss_values=}") torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) From 046a38f97205502ed5616b8a5ca248caa0d44ca5 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 7 Feb 2025 15:49:58 -0800 Subject: [PATCH 14/16] Remove reference to FSDP and TP --- recipes/full_finetune_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 01a3248d0c..b827ecdea7 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -519,7 +519,7 @@ def _setup_model( utils.log_rank_zero( log, - "Distributed training(FSDP and TP) is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + "Distributed training is enabled. 
Instantiating model and loading checkpoint on Rank 0 ...", ) init_start = time.perf_counter() From 9d375fa0d7a4ee44a4154421eaebda3628454084 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 7 Feb 2025 15:51:17 -0800 Subject: [PATCH 15/16] Consolidate distributed APIs --- recipes/full_finetune_distributed.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b827ecdea7..fa3fa7dfec 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -12,11 +12,14 @@ from warnings import warn import torch -import torch.distributed as dist from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import destroy_process_group, init_process_group +from torch.distributed import ( + destroy_process_group, + init_device_mesh, + init_process_group, +) from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer @@ -529,7 +532,7 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) - device_mesh = dist.init_device_mesh( + device_mesh = init_device_mesh( self._device.type, mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim), mesh_dim_names=("dp", "tp"), From 5ab0933ce79ac6ee2092e00992e7c2a0b9834222 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 7 Feb 2025 18:49:12 -0800 Subject: [PATCH 16/16] Fix the case where data is not sharded --- recipes/full_finetune_distributed.py | 7 +++++-- tests/recipes/test_full_finetune_distributed.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index fa3fa7dfec..a968ebf0ef 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -20,8 +20,8 @@ init_device_mesh, init_process_group, ) +from torch.distributed._tensor import DTensor from torch.distributed.tensor.parallel import parallelize_module - from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils @@ -832,7 +832,10 @@ def train(self) -> None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ).full_tensor() + ) + # If sharded, collect the DTensor here + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 904e188a4a..33012b1da5 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -134,6 +134,7 @@ def test_loss( "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim", [ ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2), + ("llama3/8B_full", "llama3", "tune", 4, 1, True, 4), ], ) @gpu_test(gpu_count=4)
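
Note on the parallelism composition this series wires into `_setup_model`: the recipe builds a 2D device mesh, applies the tensor-parallel plan over the "tp" axis with `parallelize_module` (after `prepare_mha_for_tp` rescales the local head counts), and only then shards with FSDP2 over the "dp" axis. Below is a minimal, self-contained sketch of that same pattern on a toy MLP rather than torchtune's model or `base_llama_tp_plan`; the module names and the hand-written Colwise/Rowwise plan are illustrative assumptions, and it presumes a torchrun launch with a world size divisible by the TP degree.

import os

import torch
import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class ToyBlock(nn.Module):
    """Stand-in for a transformer layer: up-projection, activation, down-projection."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def main(tensor_parallel_dim: int = 2):
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = torch.distributed.get_world_size()
    data_parallel_dim = world_size // tensor_parallel_dim

    # 2D mesh: outer "dp" axis for FSDP sharding, inner "tp" axis for tensor
    # parallelism, mirroring init_device_mesh(..., mesh_dim_names=("dp", "tp"))
    # in the recipe.
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(data_parallel_dim, tensor_parallel_dim),
        mesh_dim_names=("dp", "tp"),
    )

    model = ToyBlock().cuda()

    # Tensor parallelism first: shard w1 column-wise and w2 row-wise across the
    # "tp" axis. The recipe does the analogous step with
    # parallelize_module(model, device_mesh["tp"], parallelize_plan=...).
    parallelize_module(
        model,
        mesh["tp"],
        parallelize_plan={"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    # Then FSDP2 over the "dp" axis only, the same role dp_mesh plays in shard_model.
    fully_shard(model, mesh=mesh["dp"])

    x = torch.randn(8, 256, device="cuda")
    model(x).sum().backward()
    destroy_process_group()


if __name__ == "__main__":
    main()

Run with e.g. `torchrun --nproc_per_node 4 toy_2d.py` for a dp=2 x tp=2 layout; in the recipe the equivalent knobs are the new `tensor_parallel_dim` and `parallelize_plan` config keys, exercised by the added `test_loss_2d_parallel` test.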