Commit dfe72c4

[float8] Bug fix: do not override requires_grad=False when enable_float8_all_gather=True (#1873)
1 parent 711fa08 commit dfe72c4
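
For context on the root cause: torch.nn.Parameter defaults requires_grad to True regardless of the flag on the tensor it wraps, so re-wrapping a frozen weight during float8 conversion silently makes it trainable again unless the flag is forwarded explicitly. A minimal sketch of that behavior (toy names, plain PyTorch, no torchao involved):

import torch
import torch.nn as nn

lin = nn.Linear(8, 8)
lin.weight.requires_grad = False  # the user freezes the weight

# Re-wrapping without forwarding the flag falls back to the default True,
# which is the bug: the frozen weight silently becomes trainable again.
rewrapped = nn.Parameter(lin.weight.detach())
assert rewrapped.requires_grad is True

# The fix forwards the original flag when constructing the new Parameter.
rewrapped = nn.Parameter(lin.weight.detach(), requires_grad=lin.weight.requires_grad)
assert rewrapped.requires_grad is False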

3 files changed: +30, -5 lines changed

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 17 additions & 2 deletions

@@ -67,7 +67,10 @@ def init_multi_module(self) -> nn.Module:
         return module
 
     def init_transformer(
-        self, weight_tying: bool, dtype: Optional[torch.dtype] = None
+        self,
+        weight_tying: bool,
+        dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = True,
     ) -> nn.Module:
         torch.manual_seed(42)
         args = ModelArgs(
@@ -81,6 +84,13 @@ def init_transformer(
         module = Transformer(args).cuda()
         if dtype is not None:
             module = module.to(dtype=dtype)
+
+        # if requires_grad=False, just set requires_grad to False
+        # in the first layer to ensure we still train some params.
+        if requires_grad is False:
+            for param in module.layers[0].parameters():
+                param.requires_grad = requires_grad
+
         self.broadcast_module(module)
         return module
 
@@ -107,6 +117,7 @@ def test_transformer_parity(self):
                 ],
                 "compile_transformer_block": [False, True],
                 "dtype": [torch.float32, torch.bfloat16],
+                "requires_grad": [True, False],
             },
             self._test_transformer_parity,
         )
@@ -117,6 +128,7 @@ def _test_transformer_parity(
         precompute: bool,
         scaling_type_weight: ScalingType,
         compile_transformer_block: bool,
+        requires_grad: bool,
         dtype: Optional[torch.dtype] = None,
     ):
         if not enable_fsdp_float8_all_gather and precompute:
@@ -127,7 +139,10 @@ def _test_transformer_parity(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_float8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying, dtype=dtype)
+        module = self.init_transformer(
+            weight_tying=weight_tying, dtype=dtype, requires_grad=requires_grad
+        )
+
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
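
The added init_transformer(requires_grad=False) path freezes only the first transformer block so the test still trains some parameters. A standalone sketch of that partial-freeze pattern, using a toy nn.Sequential in place of the test's Transformer:

import torch.nn as nn

# Toy stand-in for the test's Transformer: a stack of blocks.
module = nn.Sequential(*[nn.Linear(8, 8) for _ in range(3)])

# Freeze only the first block, mirroring init_transformer(requires_grad=False).
for param in module[0].parameters():
    param.requires_grad = False

frozen = [n for n, p in module.named_parameters() if not p.requires_grad]
trainable = [n for n, p in module.named_parameters() if p.requires_grad]
assert frozen and trainable  # mixed frozen and trainable params, as in the test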

torchao/float8/float8_linear.py

Lines changed: 2 additions & 1 deletion

@@ -416,7 +416,8 @@ def from_float(
                         new_mod.weight,
                         new_mod.linear_mm_config,
                         new_mod.config.cast_config_weight.target_dtype,
-                    )
+                    ),
+                    requires_grad=new_mod.weight.requires_grad,
                 )
 
         return new_mod
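
A hedged sketch of how the fixed behavior surfaces through torchao's float8 conversion path, which calls from_float under the hood. The entry point convert_to_float8_training, its import path, and the toy model are assumptions and may differ across torchao versions; Float8LinearConfig and enable_fsdp_float8_all_gather come from the test above. Assumes a CUDA device:

import torch.nn as nn

# Assumed public API; verify against the installed torchao version.
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
for param in model[0].parameters():
    param.requires_grad = False  # freeze the first linear only

config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# With this commit, the float8 all-gather weight wrapper preserves the flag.
assert not model[0].weight.requires_grad
assert model[1].weight.requires_grad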

torchao/testing/float8/fsdp2_utils.py

Lines changed: 11 additions & 2 deletions

@@ -22,6 +22,14 @@ def check_parity_no_mp(
     precompute: bool = False,
     compile_transformer_block: bool = False,
 ):
+    # check that requires_grad matches ref module
+    for ref_param, fsdp_param in zip(ref_model.parameters(), fsdp_model.parameters()):
+        test_cls.assertEqual(
+            ref_param.requires_grad,
+            fsdp_param.requires_grad,
+            msg=f"ref_param.requires_grad: {ref_param.requires_grad}, fsdp_param.requires_grad: {fsdp_param.requires_grad}",
+        )
+
     # TODO(before land): reorder args and make config not optional
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -31,8 +39,9 @@ def check_parity_no_mp(
             losses[-1].backward()
             if model is ref_model:
                 for param in model.parameters():
-                    dist.all_reduce(param.grad)
-                    param.grad.div_(dist.get_world_size())
+                    if param.requires_grad:
+                        dist.all_reduce(param.grad)
+                        param.grad.div_(dist.get_world_size())
 
             optim.step()
             if (
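
The new if param.requires_grad: guard in the reference-model path is needed because frozen parameters never accumulate a gradient, so param.grad is None and all-reducing it would fail. A small standalone illustration:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.weight.requires_grad = False  # frozen parameter

m(torch.randn(2, 4)).sum().backward()

assert m.weight.grad is None     # frozen params get no gradient, so
assert m.bias.grad is not None   # dist.all_reduce(param.grad) must skip them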
