Skip to content

Commit 616edb0

Browse files
committed
Remove warnings
1 parent d27d58f commit 616edb0

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

tltorch/factorized_tensors/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import nn
77
import numpy as np
88

9+
import torch
910

1011
# Author: Jean Kossaifi
1112
# License: BSD 3 clause
@@ -368,12 +369,14 @@ def normal_(self, mean=0, std=1):
368369

369370
def __repr__(self):
370371
return f'{self.__class__.__name__}(shape={self.shape}, rank={self.rank})'
371-
372-
def __torch_function__(self, func, types, args=(), kwargs=None):
372+
373+
@classmethod
374+
def __torch_function__(cls, func, types, args=(), kwargs=None):
373375
if kwargs is None:
374376
kwargs = {}
375377

376378
args = [t.to_tensor() if hasattr(t, 'to_tensor') else t for t in args]
379+
# return super().__torch_function__(func, types, args, kwargs)
377380
return func(*args, **kwargs)
378381

379382
@property
@@ -549,7 +552,8 @@ def __repr__(self):
549552
msg += f'rank={self.rank})'
550553
return msg
551554

552-
def __torch_function__(self, func, types, args=(), kwargs=None):
555+
@classmethod
556+
def __torch_function__(cls, func, types, args=(), kwargs=None):
553557
if kwargs is None:
554558
kwargs = {}
555559

tltorch/factorized_tensors/tests/test_factorizations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_FactorizedTensor(factorization):
3939
res = fact_tensor[idx]
4040
if not torch.is_tensor(res):
4141
res = res.to_tensor()
42-
testing.assert_allclose(reconstruction[idx], res)
42+
testing.assert_close(reconstruction[idx], res)
4343

4444

4545
@pytest.mark.parametrize('factorization', ['BlockTT', 'CP']) #['CP', 'Tucker', 'BlockTT'])
@@ -85,7 +85,7 @@ def test_TensorizedMatrix(factorization, batch_size):
8585
res = fact_tensor[idx]
8686
if not torch.is_tensor(res):
8787
res = res.to_matrix()
88-
testing.assert_allclose(reconstruction[idx], res)
88+
testing.assert_close(reconstruction[idx], res)
8989

9090

9191
@pytest.mark.parametrize('factorization', ['CP', 'TT'])
@@ -104,7 +104,7 @@ def test_transduction(factorization):
104104

105105
indices = [slice(None)]*mode
106106
for i in range(new_dim):
107-
testing.assert_allclose(original_rec, rec[tuple(indices + [i])])
107+
testing.assert_close(original_rec, rec[tuple(indices + [i])])
108108

109109
@pytest.mark.parametrize('unsqueezed_init', ['average', 1.2])
110110
def test_tucker_init_unsqueezed_modes(unsqueezed_init):
@@ -122,7 +122,7 @@ def test_tucker_init_unsqueezed_modes(unsqueezed_init):
122122
coef = unsqueezed_init
123123

124124
for i in range(4):
125-
testing.assert_allclose(rec[:, i], mat*coef)
125+
testing.assert_close(rec[:, i], mat*coef)
126126

127127

128128
@pytest.mark.parametrize('factorization', ['ComplexCP', 'ComplexTucker', 'ComplexTT', 'ComplexDense'])

0 commit comments

Comments
 (0)