Skip to content

Commit ab692bb

Browse files
committed
Fixed Ts2Vec tests
1 parent a78b11c commit ab692bb

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed
Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,20 @@
1+
"""TS2vec tests."""
2+
13
import numpy as np
4+
import pytest
25

36
from aeon.transformations.collection.contrastive_based._ts2vec import TS2Vec
47

58

6-
7-
def test_shape():
8-
expected_features = 200
9-
X = np.random.random(size=(10, 1, 100))
10-
transformer = TS2Vec(output_dim=expected_features)
11-
transformer.fit(X)
12-
X_trans = transformer.transform(X)
13-
np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
14-
15-
def test_shape2():
16-
expected_features = 500
17-
X = np.random.random(size=(10, 1, 100))
18-
transformer = TS2Vec(output_dim=expected_features)
19-
transformer.fit(X)
20-
X_trans = transformer.transform(X)
21-
np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
22-
23-
def test_shape3():
24-
expected_features = 200
25-
X = np.random.random(size=(10, 3, 100))
26-
transformer = TS2Vec(output_dim=expected_features)
27-
transformer.fit(X)
28-
X_trans = transformer.transform(X)
29-
np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
9+
@pytest.mark.parametrize("expected_feature_size", [3, 5, 10])
10+
@pytest.mark.parametrize("n_series", [1, 2, 5])
11+
@pytest.mark.parametrize("n_channels", [1, 2, 3])
12+
@pytest.mark.parametrize("series_length", [3, 10, 20])
13+
def test_ts2vec_output_shapes(
14+
expected_feature_size, n_series, n_channels, series_length
15+
):
16+
"""Test the output shapes of the TS2Vec transformer."""
17+
X = np.random.random(size=(n_series, n_channels, series_length))
18+
transformer = TS2Vec(output_dim=expected_feature_size, device="cpu", n_epochs=2)
19+
X_t = transformer.fit_transform(X)
20+
assert X_t.shape == (n_series, expected_feature_size)

0 commit comments

Comments
 (0)