# limitations under the License.

from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple

import torch
import torch.nn.utils.parametrize as P
@@ -47,10 +48,13 @@ class TransformFactory(RegistryMixin, ABC):
    :param seed: random seed used to transform weight randomization
    """

+    transforms: List["TransformBase"]
+
    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        self.name = name
        self.scheme = scheme
        self.generator = torch.Generator()
+        self.transforms = list()
        if seed is not None:
            self.generator.manual_seed(seed)

@@ -89,6 +93,8 @@ def apply_to_model(self, model: Module):
            if is_target(name, module, arg.targets, arg.ignore):
                self._apply_to_module(module, arg)

+        self._update_tied_weights()
+
    def _apply_to_module(self, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module
@@ -143,6 +149,28 @@ def output_hook(_, _input, output):
        # (1) even in the `weight` cases, this submodule attachment is needed in order
        # to support saving in the frozen state

+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # avoid issues with this method being called twice
+        for transform in self.transforms:
+            transform._dynamic_tied_weights_keys = list()
+
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.append(name)
+

class TransformBase(Module, ABC):
    """
@@ -151,6 +179,11 @@ class TransformBase(Module, ABC):
    args: TransformArgs
    weight: Parameter
+    _dynamic_tied_weights_keys: List[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = list()

    @abstractmethod
    def forward(self, value: Tensor) -> Tensor:
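
For context, a minimal standalone sketch (not part of this commit) of the `data_ptr()`-based tie detection that the new `_update_tied_weights` method performs. The modules `a` and `b` are hypothetical stand-ins for two transforms that share a weight:

import torch
from collections import defaultdict

# Two modules sharing one Parameter: both report the same storage
# pointer, which is how _update_tied_weights groups tied weights.
shared = torch.nn.Parameter(torch.randn(4, 4))
a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
a.weight = shared
b.weight = shared

ptr_to_keys = defaultdict(list)
for module in (a, b):
    for name, param in module.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((module, name))

# Only groups with more than one entry are tied and get recorded
tied = {ptr: keys for ptr, keys in ptr_to_keys.items() if len(keys) > 1}
assert len(tied) == 1  # a.weight and b.weight share storage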
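
As a usage note on why shared pointers must be flagged before saving: safetensors refuses state dicts whose tensors share storage, so transformers needs the tied-weights keys to drop duplicates first. A small sketch of the failure mode, assuming the standard `safetensors.torch.save_file` shared-storage check:

import torch
from safetensors.torch import save_file

w = torch.randn(4, 4)
try:
    # save_file rejects tensors that share memory; keys such as
    # _dynamic_tied_weights_keys let transformers deduplicate them
    # before serialization
    save_file({"a": w, "b": w}, "out.safetensors")
except RuntimeError as err:
    print(err)  # reports that "a" and "b" share memory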