13
13
# limitations under the License.
14
14
15
15
from abc import ABC , abstractmethod
16
- from typing import Optional
16
+ from collections import defaultdict
17
+ from typing import List , Optional , Tuple
17
18
18
19
import torch
19
20
import torch .nn .utils .parametrize as P
@@ -48,10 +49,13 @@ class TransformFactory(RegistryMixin, ABC):
48
49
:param seed: random seed used for transform weight randomization
49
50
"""
50
51
52
+ transforms : List ["TransformBase" ]
53
+
51
54
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
    """
    Store the factory configuration and prepare its random generator

    :param name: name stored on the factory
    :param scheme: transform scheme stored on the factory
    :param seed: if not None, used to manually seed the factory's generator
    """
    # seed (when one is given) before attaching the generator to the factory
    rng = torch.Generator()
    if seed is not None:
        rng.manual_seed(seed)

    self.name = name
    self.scheme = scheme
    self.generator = rng
    # starts empty; `_update_tied_weights` iterates over this list
    self.transforms = []
57
61
@@ -90,6 +94,8 @@ def apply_to_model(self, model: Module):
90
94
if is_target (name , module , arg .targets , arg .ignore ):
91
95
self ._apply_to_module (module , arg )
92
96
97
+ self ._update_tied_weights ()
98
+
93
99
def _apply_to_module (self , module : Module , args : TransformArgs ):
94
100
"""
95
101
Create transforms and apply them to the module
@@ -145,6 +151,28 @@ def output_hook(_, _input, output):
145
151
else :
146
152
raise NotImplementedError ()
147
153
154
def _update_tied_weights(self):
    """
    Record, on every transform, the names of its parameters whose data
    pointer is also reported by another registered parameter. The names
    are written to `_dynamic_tied_weights_keys`, which is used by
    transformers to detect and remove shared pointers during saving
    """
    # group (transform, parameter name) pairs by the parameter's data pointer
    # NOTE(review): tensors without real storage (e.g. on the meta device)
    # may all report the same pointer; confirm such parameters cannot occur
    owners_by_ptr: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
    for module in self.transforms:
        # start from a clean list so repeated calls do not accumulate keys
        module._dynamic_tied_weights_keys = []
        for param_name, parameter in module.named_parameters(recurse=False):
            owners_by_ptr[parameter.data_ptr()].append((module, param_name))

    for owners in owners_by_ptr.values():
        # a pointer with a single owner is not shared
        if len(owners) <= 1:
            continue
        for module, param_name in owners:
            module._dynamic_tied_weights_keys.append(param_name)
175
+
148
176
149
177
class TransformBase (Module , ABC ):
150
178
"""
@@ -153,6 +181,11 @@ class TransformBase(Module, ABC):
153
181
154
182
args : TransformArgs
155
183
weight : Parameter
184
+ _dynamic_tied_weights_keys : List [str ]
185
+
186
def __init__(self):
    """
    Initialize the module and start with no tied weight keys recorded
    """
    super().__init__()
    # reassigned and populated by `TransformFactory._update_tied_weights`
    self._dynamic_tied_weights_keys = []
156
189
157
190
@abstractmethod
158
191
def forward (self , value : Tensor ) -> Tensor :
0 commit comments