
Commit 85d03de

[FSDP2] cast scale to float32 in precompute (#835)

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

1 parent: b4d0768

1 file changed

torchao/float8/fsdp_utils.py

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return

     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
@@ -69,7 +69,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
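For context, a minimal standalone sketch of the amax/scale precompute path this diff touches, not the torchao implementation itself. Assumptions are flagged in comments: the e4m3 fp8 max and the EPS clamp come from lines elided in the hunk, and plain bf16 tensors stand in for the sharded DTensor weights, so no all-reduce happens here.

import math
import torch

# plain bf16 tensors stand in for the DTensor weights (assumption)
weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]
# inf-norm per tensor == max(abs(w)); the dtype=torch.float32 kwarg is gone
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.stack(max_weights)  # one amax per weight, still bf16
amax_tensor = torch.clamp(amax_tensor, 1e-12)  # EPS guard (assumed from elided lines)
# e4m3 max (~448) / amax, as in the elided scale computation (assumed)
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
# the fix: cast each precomputed scale to float32 when storing it
scales = [scale_tensor[i].to(torch.float32) for i in range(len(weights))]
assert all(s.dtype is torch.float32 for s in scales)

Without the final cast, each stored scale would inherit the weights' low-precision dtype, since the norm computation no longer forces float32 up front.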
