7 changes: 6 additions & 1 deletion recipes/configs/llama3/70B_full.yaml
@@ -19,6 +19,11 @@

output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -54,7 +59,7 @@ epochs: 1
optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
7 changes: 6 additions & 1 deletion recipes/configs/llama3_1/70B_full.yaml
@@ -18,6 +18,11 @@

output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
lr: 2e-5
# Note: highly recommended to use fused=True optimizer flag
# with CPU offload for faster optimizer step.
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
7 changes: 6 additions & 1 deletion recipes/configs/llama3_3/70B_full.yaml
@@ -18,6 +18,11 @@

output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference.

# Parallelism
tensor_parallel_dim: 1
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
lr: 2e-5
# Note: highly recommended to use fused=True optimizer flag
# with CPU offload for faster optimizer step.
fused: True
fused: False

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
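All three 70B full-finetune configs gain the same two keys: tensor_parallel_dim (left at 1, so the default behavior stays pure FSDP) and a parallelize_plan component resolving to torchtune.models.llama3.base_llama_tp_plan. Conceptually, such a plan is a mapping from submodule paths to ParallelStyle objects from torch.distributed.tensor.parallel, which parallelize_module later applies. The sketch below only illustrates that shape; the submodule names are placeholders, not the actual contents of base_llama_tp_plan.

# Illustrative only: the shape of a tensor-parallel plan, NOT torchtune's real
# base_llama_tp_plan. Submodule paths below are placeholders.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def toy_tp_plan() -> dict:
    """Return a {submodule_path: ParallelStyle} mapping for parallelize_module."""
    return {
        # Shard the "input" projections column-wise across TP ranks...
        "attn.q_proj": ColwiseParallel(),
        "attn.k_proj": ColwiseParallel(),
        "attn.v_proj": ColwiseParallel(),
        "mlp.w1": ColwiseParallel(),
        "mlp.w3": ColwiseParallel(),
        # ...and the "output" projections row-wise, so partial results are
        # reduced back to a full activation on the way out.
        "attn.output_proj": RowwiseParallel(),
        "mlp.w2": RowwiseParallel(),
    }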
2 changes: 1 addition & 1 deletion recipes/dev/generate_v2_distributed.py
@@ -105,7 +105,7 @@ def setup(self, cfg: DictConfig) -> None:
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)

# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
training.prepare_mha_for_tp(model, tp_device_mesh)
model = training.prepare_mha_for_tp(model, tp_device_mesh)
parallelize_module(
model,
tp_device_mesh,
93 changes: 64 additions & 29 deletions recipes/full_finetune_distributed.py
@@ -15,8 +15,13 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group

from torch.distributed import (
destroy_process_group,
init_device_mesh,
init_process_group,
)
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
@@ -136,14 +141,26 @@ def __init__(self, cfg: DictConfig) -> None:
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()
self._is_rank_zero = self.rank == 0
self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
if self.tensor_parallel_dim > 1 and self.parallelize_plan is None:
raise ValueError(
"A parallelism plan must be provided when tensor parallelism is enabled."
)
if self.world_size % self.tensor_parallel_dim != 0:
raise ValueError(
f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
)
self.data_parallel_dim = self.world_size // self.tensor_parallel_dim

# Logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

if self._log_peak_memory_stats and device_type != "cuda":
log.info(
"log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
@@ -505,7 +522,7 @@ def _setup_model(

utils.log_rank_zero(
log,
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
"Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
)
init_start = time.perf_counter()

@@ -515,6 +532,24 @@
if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)

device_mesh = init_device_mesh(
self._device.type,
mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
mesh_dim_names=("dp", "tp"),
)
self.dp_size = device_mesh["dp"].size()
self.dp_rank = device_mesh["dp"].get_local_rank()

# Apply tensor parallelism to the model
if self.tensor_parallel_dim > 1:
# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
model = training.prepare_mha_for_tp(model, device_mesh["tp"])
parallelize_module(
model,
device_mesh["tp"],
parallelize_plan=self.parallelize_plan,
)

# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
# the older version of AC and this behavior is unchanged
@@ -534,19 +569,21 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# For FSDP sharding
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
# Apply Fully Sharded Data Parallelism to the model
if self.data_parallel_dim > 1:
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
dp_mesh=device_mesh["dp"],
)
]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)

with training.set_default_dtype(self._dtype), self._device:
for m in model.modules():
@@ -651,8 +688,6 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -670,7 +705,7 @@
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
)
dataloader = DataLoader(
dataset=ds,
@@ -700,8 +735,6 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
@@ -721,7 +754,7 @@
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)

pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
@@ -739,7 +772,6 @@
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
@@ -782,7 +814,7 @@ def train(self) -> None:
torch.distributed.all_reduce(running_loss)

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)
current_loss = current_loss * (self.world_size / num_tokens)

current_loss.backward()

@@ -795,12 +827,15 @@
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
training.scale_grads(self._model, self.world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
).full_tensor()
)
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

@@ -833,7 +868,7 @@ def train(self) -> None:
),
),
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
/ (time_per_step * self.world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
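Taken together, the recipe changes follow a fixed order: build a 2-D ("dp", "tp") device mesh, adjust attention for TP, apply the TP plan on the "tp" sub-mesh, and only then shard with FSDP along the "dp" sub-mesh. Below is a condensed, self-contained sketch of that ordering using the same public PyTorch APIs imported in this diff; the toy module and its plan are assumptions for illustration, a torchrun launch with an initialized process group is assumed, and fully_shard stands in for training.shard_model (which the docstring in this PR describes as a wrapper around the fully_shard API).

# Condensed sketch of the recipe's 2D (FSDP + TP) setup order. The ToyBlock and
# its plan are placeholders; torchrun / an initialized process group is assumed.
import torch
import torch.nn as nn
from torch.distributed import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.out_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.relu(self.in_proj(x)))


def setup_2d(model: nn.Module, world_size: int, tensor_parallel_dim: int) -> nn.Module:
    if world_size % tensor_parallel_dim != 0:
        raise ValueError("world_size must be divisible by tensor_parallel_dim")
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(world_size // tensor_parallel_dim, tensor_parallel_dim),
        mesh_dim_names=("dp", "tp"),
    )
    if tensor_parallel_dim > 1:
        # 1) Tensor parallelism on the "tp" sub-mesh. The recipe additionally
        #    calls training.prepare_mha_for_tp(model, mesh["tp"]) first to
        #    localize attention head counts; ToyBlock has no MHA, so it is
        #    omitted here.
        parallelize_module(
            model,
            mesh["tp"],
            parallelize_plan={"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
        )
    if mesh["dp"].size() > 1:
        # 2) FSDP sharding only along the "dp" sub-mesh.
        fully_shard(model, mesh=mesh["dp"])
    return model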
69 changes: 69 additions & 0 deletions tests/recipes/test_full_finetune_distributed.py
@@ -129,6 +129,75 @@ def test_loss(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim",
[
("llama3/8B_full", "llama3", "tune", 4, 1, True, 2),
("llama3/8B_full", "llama3", "tune", 4, 1, True, 4),
],
)
@gpu_test(gpu_count=4)
def test_loss_2d_parallel(
self,
micro_batch_size,
gradient_accumulation_steps,
config,
model_type,
ckpt_type,
optim_in_bwd,
tensor_parallel_dim,
tmpdir,
monkeypatch,
):
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
ckpt = model_type + "_" + ckpt_type
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS[model_type])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)
parallelize_plan = "torchtune.models.llama3.base_llama_tp_plan"

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)

cmd = f"""
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
tensor_parallel_dim={tensor_parallel_dim} \
parallelize_plan._component_={parallelize_plan} \
metric_logger.filename={log_file} \
""".split()
model_config = MODEL_TEST_CONFIGS[model_type]
cmd = cmd + self._get_test_config_overrides() + model_config
# "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
# wrong grad_norm, so we only test one of them each time. But loss values
# should be the same.
if not optim_in_bwd:
cmd.append("clip_grad_norm=100")
# Test that gradient clipping works with CPU offload
cmd.append("fsdp_cpu_offload=True")
else:
cmd.append("optimizer_in_bwd=True")

monkeypatch.setattr(sys, "argv", cmd)
runpy.run_path(TUNE_PATH, run_name="__main__")
loss_values = get_loss_values_from_metric_logger(log_file)
expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type)
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
19 changes: 12 additions & 7 deletions torchtune/training/_distributed.py
@@ -31,8 +31,7 @@
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune.modules import TransformerDecoder
from torchtune.modules.attention import MultiHeadAttention
from torchtune.modules.model_fusion import DeepFusionModel

from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel
from torchtune.modules.peft import get_adapter_state_dict
from torchtune.utils import get_device, get_logger
from torchtune.utils._logging import deprecated
@@ -523,6 +522,7 @@ def shard_model(
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
dp_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
@@ -541,11 +541,13 @@
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding when combining multiple parallelisms.
Defaults to None.

Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

@@ -599,11 +601,11 @@ def prepare_mha_for_tp(
>>> # num_kv_heads = 16 (32/2)
>>> # embed_dim = 2048 (4096/2)
"""
# Consider the case of Deep Fusion models
if isinstance(model, DeepFusionModel):
model = model.decoder
# Handle fusion models by extracting decoder
is_fusion_model = isinstance(model, (DeepFusionModel, EarlyFusionModel))
decoder = model.decoder if is_fusion_model else model
tp_size = tp_mesh.size()
for m in list(model.modules()):
for m in list(decoder.modules()):
if isinstance(m, MultiHeadAttention):
# Adjust attention module to use the local number of heads
if m.num_heads % tp_size != 0:
@@ -624,4 +626,7 @@
m.num_heads = m.num_heads // tp_size
m.num_kv_heads = m.num_kv_heads // tp_size
m.embed_dim = m.embed_dim // tp_size

if is_fusion_model:
model.decoder = decoder
return model
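
For quick intuition, the adjustment that prepare_mha_for_tp applies to every MultiHeadAttention module (now reached through model.decoder for both DeepFusionModel and EarlyFusionModel wrappers) boils down to the arithmetic below. This is a toy mirror, not the function itself: part of its check logic sits in the elided portion of this hunk, and the example numbers are illustrative rather than tied to a specific checkpoint.

# Toy mirror of prepare_mha_for_tp's divisibility check and the local
# (per-TP-rank) attention sizes it assigns. Numbers are illustrative.
def local_attention_dims(num_heads: int, num_kv_heads: int, embed_dim: int, tp_size: int):
    for name, value in (
        ("num_heads", num_heads),
        ("num_kv_heads", num_kv_heads),
        ("embed_dim", embed_dim),
    ):
        if value % tp_size != 0:
            raise ValueError(f"{name}={value} is not divisible by tp_size={tp_size}")
    return num_heads // tp_size, num_kv_heads // tp_size, embed_dim // tp_size


if __name__ == "__main__":
    # e.g. a 32-head, 8-KV-head, 4096-dim attention split across 2 TP ranks
    print(local_attention_dims(32, 8, 4096, 2))  # -> (16, 4, 2048)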