
Commit 5f4d0e5

FIX embeddings: fix for non-contiguous inputs + typo
1 parent 923dc54 commit 5f4d0e5

File tree

1 file changed (+4, -4)


tltorch/factorized_layers/factorized_embedding.py

Lines changed: 4 additions & 4 deletions
@@ -97,20 +97,20 @@ def forward(self, input, indices=0):
         #to handle case where input is not 1-D
         output_shape = (*input.shape, self.embedding_dim)
 
-        flatenned_input = input.view(-1)
+        flattened_input = input.reshape(-1)
 
         if self.n_layers == 1:
             if indices == 0:
-                embeddings = self.weight[flatenned_input, :]
+                embeddings = self.weight[flattened_input, :]
             else:
-                embeddings = self.weight[indices, flatenned_input, :]
+                embeddings = self.weight[indices, flattened_input, :]
 
         #CPTensorized returns CPTensorized when indexing
         if self.factorization.lower() == 'cp':
             embeddings = embeddings.to_matrix()
 
         #TuckerTensorized returns tensor not matrix,
-        #and requires reshape not view for contiguous
+        # and requires reshape not view for contiguous
         elif self.factorization.lower() == 'tucker':
             embeddings = embeddings.reshape(input.shape[0], -1)
 
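Why the view-to-reshape change fixes non-contiguous inputs: torch.Tensor.view can only reinterpret a tensor whose memory layout is compatible with the requested shape and raises a RuntimeError otherwise, while torch.Tensor.reshape falls back to copying the data when a view is impossible. A minimal sketch in plain PyTorch reproducing the failure this commit fixes (the example tensor is illustrative, not taken from the library):

import torch

# A non-contiguous index tensor, e.g. the transpose of a
# contiguous 2-D batch of indices.
indices = torch.arange(6).reshape(2, 3).t()
print(indices.is_contiguous())  # False

try:
    flat = indices.view(-1)  # view requires a compatible (contiguous) layout
except RuntimeError as err:
    print("view failed:", err)

# reshape copies when a view is impossible, so it always succeeds;
# this is the behavior the commit switches to.
flat = indices.reshape(-1)
print(flat)  # tensor([0, 3, 1, 4, 2, 5])

When the input is already contiguous, reshape returns a view without copying, so the common case pays no extra cost for this change.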
