1 file changed, 2 insertions(+), 2 deletions(-)

@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
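(Aside, not from the PR: a quick sanity check of the comment's claim and of the dtype behavior that motivates removing dtype=torch.float32 here. Without an explicit dtype, torch._foreach_norm returns results in the inputs' dtype; the bf16 weight shapes below are an illustrative assumption.)

import math
import torch

# hypothetical bf16 shards standing in for the FSDP local weights
weights = [torch.randn(8, 8, dtype=torch.bfloat16) for _ in range(3)]

# inf-norm is equivalent to max(abs(w)) per tensor
max_weights = torch._foreach_norm(weights, ord=math.inf)
for m, w in zip(max_weights, weights):
    assert torch.equal(m, w.abs().max())

# without dtype=torch.float32, the result keeps the weights' dtype,
# which is why the second hunk casts the stored scale back to fp32
assert all(m.dtype == torch.bfloat16 for m in max_weights)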
@@ -69,7 +69,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
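Taken together, the two hunks move the fp32 cast from the norm computation to the final scale assignment: the amax and scale math can now run in the weights' dtype, and only the stored _precomputed_scale is cast to fp32. A minimal sketch of that scale path under assumed values (the EPS constant and the bf16 amax below are illustrative stand-ins for what the surrounding file computes):

import torch

EPS = 1e-12  # assumed epsilon; stands in for the clamp floor in the real file

# amax as _foreach_norm now produces it: in the weights' dtype (bf16 here)
amax = torch.tensor(3.0, dtype=torch.bfloat16)
amax = torch.clamp(amax, EPS)

# scale = max representable float8 e4m3 value / amax, still bf16 at this point
scale = torch.finfo(torch.float8_e4m3fn).max / amax
assert scale.dtype == torch.bfloat16

# the second hunk's fix: cast back to fp32 when storing the precomputed scale
precomputed_scale = scale.to(torch.float32)
assert precomputed_scale.dtype == torch.float32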