Skip to content

Commit dbb5367

Browse files
aparna-aketi authored and facebook-github-bot committed
Add norm_grad_sample for RMSNorm (#755)
Summary: Pull Request resolved: #755 Llama3 model has RMSNorm and currently, we use functorch to support FGC for RMSNorm. This causes FSDP to rely on the root node for all_gather call of the RMSNorm layers. Adding norm_grad_sample method for RMSNorm to support layer-wise FSDP for this layer. Reviewed By: HuanyuZhang Differential Revision: D74334633 fbshipit-source-id: 399e98422066a8b9a5c5794538810c8f0d1c2de2
1 parent ce4605b commit dbb5367

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

opacus/grad_sample/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .gsm_exp_weights import GradSampleModuleExpandedWeights
2828
from .gsm_no_op import GradSampleModuleNoOp
2929
from .instance_norm import compute_instance_norm_grad_sample # noqa
30+
from .rms_norm import compute_rms_norm_grad_sample # noqa
3031
from .layer_norm import compute_layer_norm_grad_sample # noqa
3132
from .linear import compute_linear_grad_sample # noqa
3233
from .utils import (

opacus/grad_sample/rms_norm.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from typing import Dict, List
18+
19+
import torch
20+
import torch.nn as nn
21+
import torch.nn.functional as F
22+
from opacus.utils.tensor_utils import sum_over_all_but_batch_and_last_n
23+
24+
from .utils import register_grad_sampler
25+
26+
27+
@register_grad_sampler(nn.RMSNorm)
def compute_rms_norm_grad_sample(
    layer: nn.RMSNorm,
    activations: List[torch.Tensor],
    backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for RMSNorm.

    For ``y = rms_norm(x) * w``, the per-sample gradient w.r.t. ``w`` is
    ``rms_norm(x) * dL/dy`` summed over every dimension except the batch
    dimension and the trailing dimensions covered by ``normalized_shape``.

    Args:
        layer: RMSNorm layer whose per-sample weight gradient is computed
        activations: Activations (layer inputs); only the first element
            is used
        backprops: Backpropagated gradients w.r.t. the layer output

    Returns:
        Mapping from the layer's weight parameter to its per-sample
        gradient tensor of shape ``(batch, *normalized_shape)``. Empty
        when the layer has no trainable weight.
    """
    activations = activations[0]
    ret = {}
    # Guard against elementwise_affine=False, where layer.weight is None
    # and accessing .requires_grad would raise AttributeError.
    if layer.weight is not None and layer.weight.requires_grad:
        # Normalized input times incoming gradient, reduced over all
        # dimensions that are neither the batch dim nor part of the
        # weight's shape.
        ret[layer.weight] = sum_over_all_but_batch_and_last_n(
            F.rms_norm(activations, layer.normalized_shape, eps=layer.eps)
            * backprops,
            layer.weight.dim(),
        )
    return ret
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import hypothesis.strategies as st
17+
import torch
18+
import torch.nn as nn
19+
from hypothesis import given, settings
20+
21+
from .common import GradSampleHooks_test
22+
23+
24+
class RMSNorm_test(GradSampleHooks_test):
    @given(
        N=st.integers(1, 4),
        Z=st.integers(1, 4),
        H=st.integers(1, 3),
        W=st.integers(5, 10),
        input_dim=st.integers(2, 4),
        norm_dim=st.integers(1, 3),
    )
    @settings(deadline=60000)
    def test_input_norm(
        self, N: int, Z: int, W: int, H: int, input_dim: int, norm_dim: int
    ):
        """Compare RMSNorm per-sample gradients against the hooks-based reference."""
        # Normalizing over as many (or more) dims as the input has beyond
        # the batch dim is not a valid configuration; skip those draws.
        if norm_dim >= input_dim:
            return

        normalized_shape, x_shape = self.get_x_shape_and_norm_shape(
            H, N, W, Z, input_dim, norm_dim
        )

        layer = nn.RMSNorm(normalized_shape, elementwise_affine=True)
        sample = torch.randn(x_shape)
        self.run_test(sample, layer, batch_first=True, ew_compatible=False)

    @staticmethod
    def get_x_shape_and_norm_shape(H, N, W, Z, input_dim, norm_dim):
        """Return (normalized_shape, x_shape) for a (input_dim, norm_dim) pair.

        The input shape is the batch dim ``N`` followed by the trailing
        ``input_dim - 1`` entries of ``[Z, H, W]``; ``normalized_shape``
        covers the last ``norm_dim`` of those (a bare int when
        ``norm_dim == 1``). Only pairs with ``norm_dim < input_dim`` are
        produced by the caller.
        """
        shape_table = {
            (1, 2): (W, [N, W]),
            (1, 3): (W, [N, Z, W]),
            (1, 4): (W, [N, Z, H, W]),
            (2, 3): ([Z, W], [N, Z, W]),
            (2, 4): ([H, W], [N, Z, H, W]),
            (3, 4): ([Z, H, W], [N, Z, H, W]),
        }
        return shape_table[(norm_dim, input_dim)]

0 commit comments

Comments
 (0)