Skip to content

Commit 90416f5

Browse files
committed
TTTensor: fix indexing
Return self.__class__ to not break children inheriting
1 parent ba6a8ec commit 90416f5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tltorch/factorized_tensors/factorized_tensors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,12 +427,12 @@ def __getitem__(self, indices):
427427
# Select one dimension of one mode
428428
factor, next_factor, *factors = self.factors
429429
next_factor = tenalg.mode_dot(next_factor, factor[:, indices, :].squeeze(1), 0)
430-
return TTTensor([next_factor, *factors])
430+
return self.__class__([next_factor, *factors])
431431

432432
elif isinstance(indices, slice):
433433
mixing_factor, *factors = self.factors
434434
factors = [mixing_factor[:, indices], *factors]
435-
return TTTensor(factors)
435+
return self.__class__(factors)
436436

437437
else:
438438
factors = []
@@ -463,9 +463,9 @@ def __getitem__(self, indices):
463463
else:
464464
next_factor, *factors = self.factors[i+1:]
465465
factor = tenalg.mode_dot(next_factor, factor, 0)
466-
return TTTensor([factor, *factors])
466+
return self.__class__([factor, *factors])
467467
else:
468-
return TTTensor([*factors, factor, *self.factors[i+1:]])
468+
return self.__class__([*factors, factor, *self.factors[i+1:]])
469469

470470
def transduct(self, new_dim, mode=0, new_factor=None):
471471
"""Transduction adds a new dimension to the existing factorization

0 commit comments

Comments
 (0)