
Commit 5ffe91e

Fact. embedding: improve interface
1 parent: 8b1022e

2 files changed: +20 additions, −20 deletions


tltorch/factorized_layers/factorized_embedding.py

Lines changed: 17 additions & 13 deletions
@@ -18,9 +18,9 @@ class FactorizedEmbedding(nn.Module):
         number of entries in the lookup table
     embedding_dim : int
         number of dimensions per entry
-    auto_reshape : bool
+    auto_tensorize : bool
         whether to use automatic reshaping for the embedding dimensions
-    d : int or int tuple
+    n_tensorized_modes : int or int tuple
         number of reshape dimensions for both embedding table dimension
     tensorized_num_embeddings : int tuple
         tensorized shape of the first embedding table dimension
@@ -34,8 +34,8 @@ class FactorizedEmbedding(nn.Module):
     def __init__(self,
                  num_embeddings,
                  embedding_dim,
-                 auto_reshape=True,
-                 d=3,
+                 auto_tensorize=True,
+                 n_tensorized_modes=3,
                  tensorized_num_embeddings=None,
                  tensorized_embedding_dim=None,
                  factorization='blocktt',
@@ -45,14 +45,14 @@ def __init__(self,
                  dtype=None):
         super().__init__()

-        if auto_reshape:
+        if auto_tensorize:

             if tensorized_num_embeddings is not None and tensorized_embedding_dim is not None:
                 raise ValueError(
-                    "Either use auto_reshape or specify tensorized_num_embeddings and tensorized_embedding_dim."
+                    "Either use auto_tensorize or specify tensorized_num_embeddings and tensorized_embedding_dim."
                 )

-            tensorized_num_embeddings, tensorized_embedding_dim = get_tensorized_shape(in_features=num_embeddings, out_features=embedding_dim, order=d, min_dim=2, verbose=False)
+            tensorized_num_embeddings, tensorized_embedding_dim = get_tensorized_shape(in_features=num_embeddings, out_features=embedding_dim, order=n_tensorized_modes, min_dim=2, verbose=False)

         else:
             #check that dimensions match factorization
@@ -121,8 +121,9 @@ def from_embedding(cls,
                        embedding_layer,
                        rank=8,
                        factorization='blocktt',
+                       n_tensorized_modes=2,
                        decompose_weights=True,
-                       auto_reshape=True,
+                       auto_tensorize=True,
                        decomposition_kwargs=dict(),
                        **kwargs):
         """
@@ -137,7 +138,7 @@ def from_embedding(cls,
             tensor type
         decompose_weights: bool
             whether to decompose weights and use for initialization
-        auto_reshape: bool
+        auto_tensorize: bool
             if True, automatically reshape dimensions for TensorizedTensor
         decomposition_kwargs: dict
             specify kwargs for the decomposition
@@ -146,8 +147,9 @@ def from_embedding(cls,

         instance = cls(num_embeddings,
                        embedding_dim,
-                       auto_reshape=auto_reshape,
+                       auto_tensorize=auto_tensorize,
                        factorization=factorization,
+                       n_tensorized_modes=n_tensorized_modes,
                        rank=rank,
                        **kwargs)

@@ -166,8 +168,9 @@ def from_embedding_list(cls,
                             embedding_layer_list,
                             rank=8,
                             factorization='blocktt',
+                            n_tensorized_modes=2,
                             decompose_weights=True,
-                            auto_reshape=True,
+                            auto_tensorize=True,
                             decomposition_kwargs=dict(),
                             **kwargs):
         """
@@ -182,7 +185,7 @@ def from_embedding_list(cls,
             tensor decomposition to use
         decompose_weights: bool
             decompose weights and use for initialization
-        auto_reshape: bool
+        auto_tensorize: bool
             automatically reshape dimensions for TensorizedTensor
         decomposition_kwargs: dict
             specify kwargs for the decomposition
@@ -207,7 +210,8 @@ def from_embedding_list(cls,

         instance = cls(num_embeddings,
                        embedding_dim,
-                       auto_reshape=auto_reshape,
+                       n_tensorized_modes=n_tensorized_modes,
+                       auto_tensorize=auto_tensorize,
                        factorization=factorization,
                        rank=rank,
                        n_layers=n_layers,
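In summary, this commit renames the FactorizedEmbedding constructor arguments auto_reshape to auto_tensorize and d to n_tensorized_modes, and exposes n_tensorized_modes through the from_embedding and from_embedding_list classmethods. A minimal sketch of the renamed constructor interface follows; the import path and the nn.Embedding-style call are assumptions based on the file layout and the test file below, not part of this diff:

import torch
from tltorch.factorized_layers import FactorizedEmbedding

# Old call (before this commit):
#   FactorizedEmbedding(1000, 32, auto_reshape=True, d=3)
# New call:
embedding = FactorizedEmbedding(1000, 32,
                                auto_tensorize=True,   # renamed from auto_reshape
                                n_tensorized_modes=3,  # renamed from d
                                factorization='blocktt')

# Index lookup, analogous to torch.nn.Embedding (assumption; see the test below)
indices = torch.randint(0, 1000, (4, 7))
out = embedding(indices)  # expected shape: (4, 7, 32)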

tltorch/factorized_layers/tests/test_factorized_embedding.py

Lines changed: 3 additions & 7 deletions
@@ -11,17 +11,13 @@
 @pytest.mark.parametrize('factorization', ['CP','Tucker', 'BlockTT'])
 @pytest.mark.parametrize('dims', [(256,16), (1000,32)])
 def test_FactorizedEmbedding(factorization,dims):
-
-
-
-    NUM_EMBEDDINGS,EMBEDDING_DIM=dims
-    BATCH_SIZE = 3
+    NUM_EMBEDDINGS, EMBEDDING_DIM = dims

     #create factorized embedding
-    factorized_embedding = FactorizedEmbedding(NUM_EMBEDDINGS,EMBEDDING_DIM,factorization=factorization)
+    factorized_embedding = FactorizedEmbedding(NUM_EMBEDDINGS, EMBEDDING_DIM, factorization=factorization)

     #make test embedding of same shape and same weight
-    test_embedding = torch.nn.Embedding(factorized_embedding.weight.shape[0],factorized_embedding.weight.shape[1])
+    test_embedding = torch.nn.Embedding(factorized_embedding.weight.shape[0], factorized_embedding.weight.shape[1])
     test_embedding.weight.data.copy_(factorized_embedding.weight.to_matrix().detach())

     #create batch and test using all entries (shuffled since entries may not be sorted)
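The classmethods gain the same keyword. A hedged sketch of the updated from_embedding call, using only the parameter names and defaults visible in the diff above and making those defaults explicit (the import path is an assumption based on the file layout):

import torch
from tltorch.factorized_layers import FactorizedEmbedding

pretrained = torch.nn.Embedding(256, 16)
factorized = FactorizedEmbedding.from_embedding(pretrained,
                                                rank=8,
                                                factorization='blocktt',
                                                n_tensorized_modes=2,  # new keyword in this commit
                                                auto_tensorize=True,   # renamed from auto_reshape
                                                decompose_weights=True)

# Same check the test performs: reconstruct the dense table from the factors
approx = factorized.weight.to_matrix()
print(approx.shape)  # expected torch.Size([256, 16]); may be padded if the auto reshape rounds up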
