@@ -18,9 +18,9 @@ class FactorizedEmbedding(nn.Module):
18
18
number of entries in the lookup table
19
19
embedding_dim : int
20
20
number of dimensions per entry
21
- auto_reshape : bool
21
+ auto_tensorize : bool
22
22
whether to use automatic reshaping for the embedding dimensions
23
- d : int or int tuple
23
+ n_tensorized_modes : int or int tuple
24
24
number of reshape dimensions for both embedding table dimension
25
25
tensorized_num_embeddings : int tuple
26
26
tensorized shape of the first embedding table dimension
@@ -34,8 +34,8 @@ class FactorizedEmbedding(nn.Module):
34
34
def __init__ (self ,
35
35
num_embeddings ,
36
36
embedding_dim ,
37
- auto_reshape = True ,
38
- d = 3 ,
37
+ auto_tensorize = True ,
38
+ n_tensorized_modes = 3 ,
39
39
tensorized_num_embeddings = None ,
40
40
tensorized_embedding_dim = None ,
41
41
factorization = 'blocktt' ,
@@ -45,14 +45,14 @@ def __init__(self,
45
45
dtype = None ):
46
46
super ().__init__ ()
47
47
48
- if auto_reshape :
48
+ if auto_tensorize :
49
49
50
50
if tensorized_num_embeddings is not None and tensorized_embedding_dim is not None :
51
51
raise ValueError (
52
- "Either use auto_reshape or specify tensorized_num_embeddings and tensorized_embedding_dim."
52
+ "Either use auto_tensorize or specify tensorized_num_embeddings and tensorized_embedding_dim."
53
53
)
54
54
55
- tensorized_num_embeddings , tensorized_embedding_dim = get_tensorized_shape (in_features = num_embeddings , out_features = embedding_dim , order = d , min_dim = 2 , verbose = False )
55
+ tensorized_num_embeddings , tensorized_embedding_dim = get_tensorized_shape (in_features = num_embeddings , out_features = embedding_dim , order = n_tensorized_modes , min_dim = 2 , verbose = False )
56
56
57
57
else :
58
58
#check that dimensions match factorization
@@ -121,8 +121,9 @@ def from_embedding(cls,
121
121
embedding_layer ,
122
122
rank = 8 ,
123
123
factorization = 'blocktt' ,
124
+ n_tensorized_modes = 2 ,
124
125
decompose_weights = True ,
125
- auto_reshape = True ,
126
+ auto_tensorize = True ,
126
127
decomposition_kwargs = dict (),
127
128
** kwargs ):
128
129
"""
@@ -137,7 +138,7 @@ def from_embedding(cls,
137
138
tensor type
138
139
decompose_weights: bool
139
140
whether to decompose weights and use for initialization
140
- auto_reshape : bool
141
+ auto_tensorize : bool
141
142
if True, automatically reshape dimensions for TensorizedTensor
142
143
decomposition_kwargs: dict
143
144
specify kwargs for the decomposition
@@ -146,8 +147,9 @@ def from_embedding(cls,
146
147
147
148
instance = cls (num_embeddings ,
148
149
embedding_dim ,
149
- auto_reshape = auto_reshape ,
150
+ auto_tensorize = auto_tensorize ,
150
151
factorization = factorization ,
152
+ n_tensorized_modes = n_tensorized_modes ,
151
153
rank = rank ,
152
154
** kwargs )
153
155
@@ -166,8 +168,9 @@ def from_embedding_list(cls,
166
168
embedding_layer_list ,
167
169
rank = 8 ,
168
170
factorization = 'blocktt' ,
171
+ n_tensorized_modes = 2 ,
169
172
decompose_weights = True ,
170
- auto_reshape = True ,
173
+ auto_tensorize = True ,
171
174
decomposition_kwargs = dict (),
172
175
** kwargs ):
173
176
"""
@@ -182,7 +185,7 @@ def from_embedding_list(cls,
182
185
tensor decomposition to use
183
186
decompose_weights: bool
184
187
decompose weights and use for initialization
185
- auto_reshape : bool
188
+ auto_tensorize : bool
186
189
automatically reshape dimensions for TensorizedTensor
187
190
decomposition_kwargs: dict
188
191
specify kwargs for the decomposition
@@ -207,7 +210,8 @@ def from_embedding_list(cls,
207
210
208
211
instance = cls (num_embeddings ,
209
212
embedding_dim ,
210
- auto_reshape = auto_reshape ,
213
+ n_tensorized_modes = n_tensorized_modes ,
214
+ auto_tensorize = auto_tensorize ,
211
215
factorization = factorization ,
212
216
rank = rank ,
213
217
n_layers = n_layers ,
0 commit comments