Commit 0b44d4c

support AMP for DDP / single-device training (#1303)
There have been several requests, questions, and attempts around this feature; see #630, #700, #1278, and #1293. Hence this PR, even though AMP does not provide full support when `fully_shard` is not used. My local testing shows:
- Under TP, throughput drops with AMP.
- Under PP, there is no way to wrap the forward pass only (cc @H-Huang); if forward+backward is wrapped, the program hangs.
- It works fine under DDP / single-device training.

This PR also adds logging on whether mixed precision is enabled and, if so, under which mechanism: `fully_shard` or AMP.
1 parent 2887250 commit 0b44d4c
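
For readers unfamiliar with the pattern, the core of this change is wrapping only the forward pass and loss computation in torch.autocast while the backward pass runs outside the context. Below is a minimal, hedged sketch of that recipe for single-device bfloat16 training, using a hypothetical toy model and optimizer rather than torchtitan's actual trainer:

import torch
import torch.nn as nn

# Hypothetical toy setup; torchtitan builds these from its job config.
model = nn.Linear(1024, 1024).cuda()
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

inputs = torch.randn(8, 1024, device="cuda")
labels = torch.randn(8, 1024, device="cuda")

# AMP: only the forward pass and loss computation run under autocast.
with torch.autocast("cuda", dtype=torch.bfloat16):
    pred = model(inputs)
    loss = loss_fn(pred, labels)

# Backward and optimizer step stay outside the autocast context.
loss.backward()
optimizer.step()
optimizer.zero_grad()

Since bfloat16 has the same exponent range as float32, no GradScaler is needed; the diff to torchtitan/train.py below applies this same with-block pattern inside forward_backward_step.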

File tree: 4 files changed, +41 / -4 lines

torchtitan/config_manager.py
torchtitan/distributed/parallel_dims.py
torchtitan/distributed/utils.py
torchtitan/train.py

torchtitan/config_manager.py

Lines changed: 4 additions & 2 deletions

@@ -220,8 +220,10 @@ class Training:
 
     mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
     """
-    torch dtype to use for parameters when applying mixed precision via FSDP.
-    This feature only takes effect when data_parallel_shard_degree > 1
+    torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
+    This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
+    context_parallel_degree > 1; it takes effect via torch.autocast when data_replicate_degree >= 1
+    and no other parallelism is enabled, i.e. under DDP or single-device training.
     """
 
     mixed_precision_reduce: Literal["float32"] = "float32"
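
To illustrate how this string-valued option becomes an actual dtype, here is a hedged sketch of what a mapping like TORCH_DTYPE_MAP (imported in the torchtitan/distributed/utils.py diff below) plausibly looks like; the exact contents in torchtitan may differ:

import torch

# Assumed shape of the config-string -> dtype mapping; the real TORCH_DTYPE_MAP
# lives in torchtitan/config_manager.py and may contain more entries.
TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

mixed_precision_param = "bfloat16"  # value of Training.mixed_precision_param
param_dtype = TORCH_DTYPE_MAP[mixed_precision_param]
assert param_dtype == torch.bfloat16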

torchtitan/distributed/parallel_dims.py

Lines changed: 4 additions & 0 deletions

@@ -124,6 +124,10 @@ def cp_enabled(self):
     def dp_cp_enabled(self):
         return self.dp_enabled or self.cp_enabled
 
+    @property
+    def fsdp_enabled(self):
+        return self.dp_shard_enabled or self.cp_enabled
+
     @property
     def tp_enabled(self):
         return self.tp > 1
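
The new property folds the two cases where fully_shard owns mixed precision into a single flag. As a hedged illustration of how it composes with the neighboring properties (assuming, by analogy with tp_enabled above, that dp_shard_enabled and cp_enabled are simple degree checks; the real ParallelDims class has more fields):

class ParallelDimsSketch:
    """Simplified stand-in for torchtitan's ParallelDims, for illustration only."""

    def __init__(self, dp_shard: int, cp: int, tp: int, pp: int):
        self.dp_shard = dp_shard
        self.cp = cp
        self.tp = tp
        self.pp = pp

    @property
    def dp_shard_enabled(self) -> bool:
        return self.dp_shard > 1

    @property
    def cp_enabled(self) -> bool:
        return self.cp > 1

    @property
    def fsdp_enabled(self) -> bool:
        # fully_shard owns mixed precision whenever sharded data parallel or context parallel is on.
        return self.dp_shard_enabled or self.cp_enabled

# Single-device / DDP-style run: fully_shard is not involved, so AMP can take over.
assert not ParallelDimsSketch(dp_shard=1, cp=1, tp=1, pp=1).fsdp_enabled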

torchtitan/distributed/utils.py

Lines changed: 25 additions & 0 deletions

@@ -18,6 +18,8 @@
 from torch.distributed.tensor import DTensor
 from torch.nn.attention import SDPBackend
 
+from torchtitan.config_manager import TORCH_DTYPE_MAP
+from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.models.attention import ScaledDotProductAttention
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type

@@ -202,6 +204,29 @@ def context(cp_context: Generator[None, None, None] | None = None):
     return context
 
 
+def maybe_enable_amp(
+    parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device
+) -> Generator[None, None, None]:
+    if parallel_dims.fsdp_enabled:
+        # FSDP handles mixed precision internally
+        logger.info("Mixed precision training is handled by fully_shard")
+        return contextlib.nullcontext()
+    else:
+        if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
+            logger.warning(
+                "Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP is enabled."
+            )
+            logger.info("Mixed precision training is disabled")
+            return contextlib.nullcontext()
+        else:
+            # the following code will only be executed for DDP or single-device training
+            logger.info("Mixed precision training is handled by AMP")
+            return torch.autocast(
+                device_type,
+                dtype=TORCH_DTYPE_MAP[mixed_precision_param],
+            )
+
+
 def init_distributed(job_config):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
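
As a quick, standalone sanity check of what the returned torch.autocast context does (not torchtitan code; CPU is used so it runs anywhere), activations computed inside the context come out in bfloat16 while the module's parameters stay in float32:

import torch
import torch.nn as nn

linear = nn.Linear(16, 16)  # parameters are created in float32
x = torch.randn(4, 16)

with torch.autocast("cpu", dtype=torch.bfloat16):
    y = linear(x)

print(linear.weight.dtype)  # torch.float32 -- master weights are untouched
print(y.dtype)              # torch.bfloat16 -- the matmul ran in reduced precision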

torchtitan/train.py

Lines changed: 8 additions & 2 deletions

@@ -313,6 +313,11 @@ def __init__(self, job_config: JobConfig):
             parallel_dims.loss_parallel_enabled,
             parallelism_config.enable_compiled_autograd,
         )
+        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
+            parallel_dims,
+            job_config.training.mixed_precision_param,
+            device_type,
+        )
 
         logger.info(
             "Trainer is initialized with "

@@ -400,8 +405,9 @@ def forward_backward_step(
         # Non-PP forward / backward
         with self.train_context(optional_context_parallel_ctx):
             assert len(model_parts) == 1
-            pred = model_parts[0](inputs)
-            loss = self.loss_fn(pred, labels)
+            with self.maybe_enable_amp:
+                pred = model_parts[0](inputs)
+                loss = self.loss_fn(pred, labels)
             # need to free to before bwd to avoid peaking memory
             del pred
             loss.backward()
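
One design note on the diff above: because maybe_enable_amp returns contextlib.nullcontext() whenever AMP should not be active, the training step can always enter the same with block without branching. A minimal sketch of that pattern with hypothetical names (not torchtitan code):

import contextlib
import torch

def pick_precision_context(use_amp: bool):
    # Mirrors the shape of maybe_enable_amp: a real autocast context or a no-op.
    if use_amp:
        return torch.autocast("cpu", dtype=torch.bfloat16)
    return contextlib.nullcontext()

ctx = pick_precision_context(use_amp=False)
with ctx:  # no-op here, so the matmul below runs in full precision
    out = torch.randn(2, 2) @ torch.randn(2, 2)
print(out.dtype)  # torch.float32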
