
Commit 2843388

fix torchtitan + float8 + delayed + compile (#1334)
Summary: At some point torchtitan + delayed scaling + compile broke; this PR fixes it by switching to functional collectives for the amax all-reduce. It would be great to add a local repro; will follow up offline on what could be missing in our current test coverage.

Test Plan:

```
// torchtitan run which is fixed by this PR
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.scaling_type_input delayed --float8.scaling_type_weight delayed --float8.scaling_type_grad_output delayed

// error message without this PR: https://gist.github.com/vkuzo/dbf54cf4027fd49bfb8095d518c618af
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent b2e42ff commit 2843388

File tree

1 file changed: +6 −1 lines changed

torchao/float8/float8_utils.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -8,6 +8,7 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
 
 from torchao.float8.config import Float8TypeConfig, ScalingGranularity
 
@@ -109,7 +110,11 @@ def tensor_to_amax(
     # happen elsewhere.
     if reduce_amax and dist.is_initialized():
         pg = device_mesh.get_group() if device_mesh is not None else None
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
+        # dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
+        group = list(range(dist.get_world_size())) if pg is None else pg
+        amax = all_reduce(amax, "MAX", group)
+        if isinstance(amax, AsyncCollectiveTensor):
+            amax = amax.wait()
 
     return amax
 
```
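
For context, the key difference is that `torch.distributed._functional_collectives.all_reduce` returns a result tensor (possibly an `AsyncCollectiveTensor`) instead of mutating its input in place, which is the form torch.compile can trace. Below is a minimal standalone sketch of the same pattern, not the torchao implementation; it assumes a `torchrun` launch, and the names `reduced_amax` and `amax_sketch.py` are illustrative.

```python
# Minimal sketch of the functional-collectives amax reduction pattern.
# Assumption: launched via `torchrun --nproc_per_node=2 amax_sketch.py`;
# `reduced_amax` is an illustrative name, not a torchao API.
import os

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce


def reduced_amax(x: torch.Tensor) -> torch.Tensor:
    # Local absolute max of the tensor.
    amax = torch.max(torch.abs(x))
    # Functional collective: returns a new tensor instead of mutating `amax`
    # in place, so the op can be traced by torch.compile.
    group = list(range(dist.get_world_size()))
    amax = all_reduce(amax, "MAX", group)
    if isinstance(amax, AsyncCollectiveTensor):
        # In eager mode the result is async; wait for it before using it.
        amax = amax.wait()
    return amax


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    torch.manual_seed(int(os.environ["RANK"]))  # different data per rank
    x = torch.randn(16)
    print("eager:", reduced_amax(x))
    print("compiled:", torch.compile(reduced_amax)(x))
    dist.destroy_process_group()
```

Both calls should print the same cross-rank max on every rank; the compiled call is the configuration (delayed scaling + compile) that the in-place `dist.all_reduce` was breaking.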
