Commit d7f369c

memory-optimized ff

1 parent a647682
File tree

2 files changed: +188 −0 lines
src/diffusers/models/attention.py

Lines changed: 10 additions & 0 deletions

`FeedForward.__init__` now caches its constructor arguments as private attributes so that `apply_memory_optimized_feedforward` (added below in `memory_utils.py`) can rebuild an equivalent, memory-optimized module with the same configuration.

```diff
@@ -1215,10 +1215,20 @@ def __init__(
         bias: bool = True,
     ):
         super().__init__()
+
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
 
+        self._dim = dim
+        self._dim_out = dim_out
+        self._mult = mult
+        self._dropout = dropout
+        self._activation_fn = activation_fn
+        self._final_dropout = final_dropout
+        self._inner_dim = inner_dim
+        self._bias = bias
+
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
         if activation_fn == "gelu-approximate":
```

src/diffusers/models/memory_utils.py

Lines changed: 178 additions & 0 deletions

New file:

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from ..utils import logging
from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU
from .attention import FeedForward


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class _MemoryOptimizedFeedForward(torch.nn.Module):
    r"""
    See [`~models.attention.FeedForward`] for parameter documentation. This class is a copy of the FeedForward class;
    the only difference is that this module is optimized for memory.

    The memory savings come from applying the ideas of tensor parallelism sequentially: input projection layers are
    split column-wise and output projection layers are split row-wise, so the feedforward pass can be computed
    without ever materializing the full intermediate tensor. Typically, the intermediate tensor takes 4x-8x more
    memory than the input tensor; this method reduces that at a small performance cost.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim: Optional[int] = None,
        bias: bool = True,
        num_splits: int = 4,
    ) -> None:
        super().__init__()

        if inner_dim is None:
            inner_dim = int(dim * mult)

        dim_out = dim_out if dim_out is not None else dim

        dim_split = inner_dim // num_splits
        if inner_dim % num_splits != 0:
            raise ValueError(f"inner_dim must be divisible by {mult=}, or {num_splits=} if provided.")

        self._dim = dim
        self._dim_out = dim_out
        self._mult = mult
        self._dropout = dropout
        self._activation_fn = activation_fn
        self._final_dropout = final_dropout
        self._inner_dim = inner_dim
        self._bias = bias
        self._num_splits = num_splits

        def get_activation_fn(dim_: int, inner_dim_: int):
            if activation_fn == "gelu":
                act_fn = GELU(dim_, inner_dim_, bias=bias)
            if activation_fn == "gelu-approximate":
                act_fn = GELU(dim_, inner_dim_, approximate="tanh", bias=bias)
            elif activation_fn == "geglu":
                act_fn = GEGLU(dim_, inner_dim_, bias=bias)
            elif activation_fn == "geglu-approximate":
                act_fn = ApproximateGELU(dim_, inner_dim_, bias=bias)
            elif activation_fn == "swiglu":
                act_fn = SwiGLU(dim_, inner_dim_, bias=bias)
            elif activation_fn == "linear-silu":
                act_fn = LinearActivation(dim_, inner_dim_, bias=bias, activation="silu")
            return act_fn

        # Split column-wise
        self.proj_in = torch.nn.ModuleList([get_activation_fn(dim, dim_split) for _ in range(inner_dim // dim_split)])

        self.dropout = torch.nn.Dropout(dropout)

        # Split row-wise
        self.proj_out = torch.nn.ModuleList(
            [torch.nn.Linear(dim_split, dim_out, bias=False) for _ in range(inner_dim // dim_split)]
        )

        self.bias = None
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(dim_out))

        self.final_dropout = None
        if final_dropout:
            self.final_dropout = torch.nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Output tensor for the "all_reduce" operation (its last dimension is dim_out, not dim)
        output = hidden_states.new_zeros((*hidden_states.shape[:-1], self._dim_out))

        # Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
        for proj_in, proj_out in zip(self.proj_in, self.proj_out):
            out = proj_in(hidden_states)
            out = self.dropout(out)
            out = proj_out(out)
            # Perform "all_reduce"
            output += out

        if self.bias is not None:
            output += self.bias
        if self.final_dropout is not None:
            output = self.final_dropout(output)

        return output


def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Optional[int] = None) -> torch.nn.Module:
    module_dict = dict(module.named_modules())

    for name, submodule in module_dict.items():
        if not isinstance(submodule, FeedForward):
            continue

        logger.debug(f"Applying memory optimized feedforward to layer '{name}'")
        state_dict = submodule.state_dict()
        num_splits = submodule._mult if num_splits is None else num_splits

        # remap net.0.proj.weight
        net_0_proj = state_dict.pop("net.0.proj.weight")
        net_0_proj = net_0_proj.chunk(num_splits, dim=0)
        for i in range(num_splits):
            state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i]

        # remap net.0.proj.bias
        if "net.0.proj.bias" in state_dict:
            net_0_proj_bias = state_dict.pop("net.0.proj.bias")
            net_0_proj_bias = net_0_proj_bias.chunk(num_splits, dim=0)
            for i in range(num_splits):
                state_dict[f"proj_in.{i}.proj.bias"] = net_0_proj_bias[i]

        # remap net.2.weight
        net_2_weight = state_dict.pop("net.2.weight")
        net_2_weight = net_2_weight.chunk(num_splits, dim=1)
        for i in range(num_splits):
            state_dict[f"proj_out.{i}.weight"] = net_2_weight[i]

        # remap net.2.bias
        if "net.2.bias" in state_dict:
            net_2_bias = state_dict.pop("net.2.bias")
            state_dict["bias"] = net_2_bias

        # Build the replacement on the meta device, then load the remapped weights with assign=True
        with torch.device("meta"):
            new_ff = _MemoryOptimizedFeedForward(
                dim=submodule._dim,
                dim_out=submodule._dim_out,
                mult=submodule._mult,
                dropout=submodule._dropout,
                activation_fn=submodule._activation_fn,
                final_dropout=submodule._final_dropout,
                inner_dim=submodule._inner_dim,
                bias=submodule._bias,
                num_splits=num_splits,
            )

        new_ff.load_state_dict(state_dict, strict=True, assign=True)

        parent_module_name, _, submodule_name = name.rpartition(".")
        parent_module = module_dict[parent_module_name]
        setattr(parent_module, submodule_name, new_ff)

    return module
```
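For reference, a minimal sketch of how the conversion could be exercised and compared against the stock `FeedForward`. This assumes this commit's branch is installed (so `FeedForward` carries the cached `_dim`/`_mult`/... attributes and `diffusers.models.memory_utils` is importable from the path above); the container, shapes, and split count are illustrative. A non-gated activation (`"gelu"`) is used here, since a plain row-wise chunk of a gated projection (GEGLU/SwiGLU) would not preserve the original hidden/gate pairing.

```python
import torch

from diffusers.models.attention import FeedForward
from diffusers.models.memory_utils import apply_memory_optimized_feedforward  # path assumed from this commit


torch.manual_seed(0)

# Stand-in for a transformer block that owns a FeedForward submodule.
block = torch.nn.ModuleDict({"ff": FeedForward(dim=64, mult=4, activation_fn="gelu")})
x = torch.randn(2, 16, 64)

with torch.no_grad():
    expected = block["ff"](x)

# Swap every FeedForward submodule for the memory-optimized variant (4 sequential splits here).
apply_memory_optimized_feedforward(block, num_splits=4)
print(type(block["ff"]).__name__)  # _MemoryOptimizedFeedForward

with torch.no_grad():
    actual = block["ff"](x)

# Summing the column-/row-split partial products should reproduce the original output.
print(torch.allclose(expected, actual, atol=1e-5))  # expected: True
```

With `num_splits=4`, the largest intermediate tensor alive at any step has shape `(2, 16, 64)` instead of `(2, 16, 256)`, which is where the memory saving comes from.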

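The docstring's point about the intermediate tensor can also be checked empirically. A rough sketch, assuming a CUDA device and the same branch; `peak_forward_mib` and all shapes are illustrative, and exact numbers depend on the allocator:

```python
import torch

from diffusers.models.attention import FeedForward
from diffusers.models.memory_utils import apply_memory_optimized_feedforward  # path assumed from this commit


def peak_forward_mib(module: torch.nn.Module, x: torch.Tensor) -> float:
    """Peak CUDA memory (in MiB) allocated during a single forward pass."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        module(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20


device = "cuda"
# Wrap the FeedForward in a container so apply_memory_optimized_feedforward can replace it in its parent.
block = torch.nn.Sequential(FeedForward(dim=3072, mult=4, activation_fn="gelu")).to(device)
x = torch.randn(1, 4096, 3072, device=device)

before = peak_forward_mib(block, x)
apply_memory_optimized_feedforward(block, num_splits=8)
after = peak_forward_mib(block, x)

# After conversion, the full (1, 4096, 12288) intermediate tensor is never materialized;
# only one (1, 4096, 1536) slice exists at a time, so the activation peak should drop.
print(f"peak forward memory: {before:.1f} MiB -> {after:.1f} MiB")
```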