Skip to content

Commit 6e4c2ff

Browse files
committed
Modified version conflict.
1 parent 61019f9 commit 6e4c2ff

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

hyperts/framework/stats/sktime_ex/_sfa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _igb(self, dft, y):
519519
breakpoints = np.zeros((self.word_length, self.alphabet_size))
520520
clf = DecisionTreeClassifier(
521521
criterion="entropy",
522-
max_depth=np.log2(self.alphabet_size),
522+
max_depth=int(np.floor(np.log2(self.alphabet_size))),
523523
max_leaf_nodes=self.alphabet_size,
524524
random_state=1,
525525
)

hyperts/framework/stats/sktime_ex/_tsf.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ def __init__(
6666
# We need to add is-fitted state when inheriting from scikit-learn
6767
self._is_fitted = False
6868

69+
@property
70+
def _estimator(self):
71+
"""Access first parameter in self, self inheriting from sklearn BaseForest.
72+
73+
The attribute was renamed from base_estimator to estimator in sklearn 1.2.0.
74+
"""
75+
import sklearn
76+
from packaging.specifiers import SpecifierSet
77+
78+
sklearn_version = sklearn.__version__
79+
80+
if sklearn_version in SpecifierSet(">=1.2.0"):
81+
return self.estimator
82+
else:
83+
return self.base_estimator
84+
6985
def fit(self, X, y):
7086
"""Build a forest of trees from the training set (X, y).
7187
@@ -110,7 +126,7 @@ def fit(self, X, y):
110126

111127
self.estimators_ = Parallel(n_jobs=n_jobs)(
112128
delayed(_fit_estimator)(
113-
_clone_estimator(self.base_estimator, rng), X, y, self.intervals_[i]
129+
_clone_estimator(self._estimator, rng), X, y, self.intervals_[i]
114130
)
115131
for i in range(self.n_estimators)
116132
)

0 commit comments

Comments
 (0)