Skip to content

Commit f5626da

Browse files
committed
FIX tests: input dtype
1 parent 6b93e33 commit f5626da

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

tltorch/factorized_layers/tests/test_factorized_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_FactorizedLinear(factorization):
1616
in_shape = (3, 3)
1717
out_features = 16
1818
out_shape = (4, 4)
19-
data = tl.tensor(rng.random_sample((batch_size, in_features)))
19+
data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
2020

2121
# Creat from a tensor factorization
2222
tensor = TensorizedTensor.new((out_shape, in_shape), rank='same', factorization=factorization)

tltorch/factorized_layers/tests/test_tensor_contraction_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_tcl():
1010
batch_size = 2
1111
in_shape = (4, 5, 6)
1212
out_shape = (2, 3, 5)
13-
data = tl.tensor(rng.random_sample((batch_size, ) + in_shape))
13+
data = tl.tensor(rng.random_sample((batch_size, ) + in_shape), dtype=tl.float32)
1414

1515
expected_shape = (batch_size, ) + out_shape
1616
tcl = TCL(input_shape=in_shape, rank=out_shape, bias=False)

tltorch/factorized_layers/tests/test_trl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def test_trl(factorization, true_rank, rank):
7676
tol = 0.08
7777

7878
# Generate a random tensor
79-
samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1))
80-
true_bias = tl.tensor(rng.uniform(size=output_shape))
79+
samples = tl.tensor(rng.normal(size=(batch_size, *input_shape), loc=0, scale=1), dtype=tl.float32)
80+
true_bias = tl.tensor(rng.uniform(size=output_shape), dtype=tl.float32)
8181

8282
with torch.no_grad():
8383
true_weight = FactorizedTensor.new(shape=input_shape+output_shape,
@@ -130,7 +130,7 @@ def test_TuckerTRL(order, project_input, learn_pool):
130130
# fix the random seed for reproducibility and create random input
131131
random_state = 12345
132132
rng = tl.check_random_state(random_state)
133-
data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order))
133+
data = tl.tensor(rng.random_sample((batch_size, in_features) + (spatial_size, )*order), dtype=tl.float32)
134134

135135
# Build a simple net with avg-pool, flatten + fully-connected
136136
if order == 2:
@@ -182,7 +182,7 @@ def test_TRL_from_linear(factorization, bias):
182182
# fix the random seed for reproducibility and create random input
183183
random_state = 12345
184184
rng = tl.check_random_state(random_state)
185-
data = tl.tensor(rng.random_sample((batch_size, in_features)))
185+
data = tl.tensor(rng.random_sample((batch_size, in_features)), dtype=tl.float32)
186186
fc = nn.Linear(in_features, out_features, bias=bias)
187187
res_fc = fc(tl.copy(data))
188188
trl = TRL((in_features, ), (out_features, ), rank=10, bias=bias, factorization=factorization)

tltorch/functional/tests/test_factorized_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_linear_tensor_dot_tucker(factorization, factorized_linear):
2121
rank = 3
2222
batch_size = 2
2323

24-
tensor = tl.randn((batch_size, in_dim))
24+
tensor = tl.randn((batch_size, in_dim), dtype=tl.float32)
2525
fact_weight = TensorizedTensor.new((out_shape, in_shape), rank=rank,
2626
factorization=factorization)
2727
fact_weight.normal_()

0 commit comments

Comments
 (0)