Skip to content

Commit 786bf3f

Browse files
authored
[BUG] Fixes a bug with SFAFast throwing an error when calling transform after fit (#2897)
* Bugfix for SFA_Fast with transform called after fit
1 parent f70439c commit 786bf3f

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

aeon/transformations/collection/dictionary_based/_sfa_fast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _fit(self, X, y=None):
329329
self: object
330330
"""
331331
# with parallel_backend("loky", inner_max_num_threads=n_jobs):
332-
self._fit_transform(X, y, return_bag_of_words=False)
332+
self._fit_transform(X, y, return_bag_of_words=self.return_sparse)
333333
return self
334334

335335
def _transform(self, X, y=None):

aeon/transformations/collection/dictionary_based/tests/test_sfa.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66
import pytest
77

8-
from aeon.transformations.collection.dictionary_based._sfa import SFA
8+
from aeon.datasets import load_unit_test
9+
from aeon.transformations.collection.dictionary_based import SFA, SFAFast
910

1011

1112
@pytest.mark.parametrize(
@@ -224,3 +225,27 @@ def test_typed_dict():
224225
word_list2 = p2.bag_to_string(p2.transform(X, y)[0][0])
225226

226227
assert word_list == word_list2
228+
229+
230+
def test_sfa_fast_transform_after_fit():
231+
"""Test transform called after fit returns the same result as fit_transform()."""
232+
X_train, y_train = load_unit_test(split="train")
233+
234+
# Fit, then transform
235+
sfa = SFAFast()
236+
sfa.fit(X_train, y_train)
237+
x = sfa.transform(X_train, y_train)
238+
239+
# Fit_transform, then transform
240+
sfa = SFAFast()
241+
sfa.fit_transform(X_train, y_train)
242+
y = sfa.transform(X_train, y_train)
243+
244+
# Assert that the two csr_matrix are equal
245+
assert (
246+
x.shape == y.shape
247+
and x.dtype == y.dtype
248+
and np.all(x.indices == y.indices)
249+
and np.all(x.indptr == y.indptr)
250+
and np.allclose(x.data, y.data)
251+
)

0 commit comments

Comments
 (0)