Skip to content

Commit 2007273

Browse files
authored
Merge pull request #275 from juaml/fix/invert_cv_order
Invert order of CV so it runs the largest first
2 parents d0fcf31 + ffca079 commit 2007273

File tree

6 files changed

+25
-23
lines changed

6 files changed

+25
-23
lines changed

docs/changes/newsfragments/275.enh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Place the final model CV split at the beginning instead of the end of the CV iterator wrapper by `Fede Raimondo`_

julearn/api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,13 +594,13 @@ def run_cross_validation(
594594
)
595595

596596
if include_final_model:
597-
# If we include the final model, we need to remove the last item in
597+
# If we include the final model, we need to remove the first item in
598598
# the scores as this is the final model
599-
pipeline = scores["estimator"][-1]
599+
pipeline = scores["estimator"][0]
600600
if return_estimator == "final":
601601
scores.pop("estimator")
602-
scores = {k: v[:-1] for k, v in scores.items()}
603-
fold_sizes = fold_sizes[:-1]
602+
scores = {k: v[1:] for k, v in scores.items()}
603+
fold_sizes = fold_sizes[1:]
604604

605605
n_repeats = getattr(cv_outer, "n_repeats", 1)
606606
n_folds = len(scores["fit_time"]) // n_repeats

julearn/model_selection/final_model_cv.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,12 @@ def split(
6868
profiting for joblib calls.
6969
7070
"""
71-
yield from self.cv.split(X, y, groups)
71+
# For the first fold, train on all samples and return only 2 for testing
7272
all_inds = np.arange(len(X))
73-
# For the last fold, train on all samples and return only 2 for testing
7473
yield all_inds, all_inds[:2]
7574

75+
yield from self.cv.split(X, y, groups)
76+
7677
def get_n_splits(self) -> int:
7778
"""Get the number of splits.
7879

julearn/model_selection/tests/test_final_model_cv.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def test_final_model_cv() -> None:
3131
all_sk = list(sklearn_cv.split(X, y))
3232

3333
assert len(all_ju) == len(all_sk) + 1
34-
for i in range(10):
35-
assert_array_equal(all_ju[i][0], all_sk[i][0])
36-
assert_array_equal( all_ju[i][1], all_sk[i][1])
34+
for i in range(1, 11):
35+
assert_array_equal(all_ju[i][0], all_sk[i-1][0])
36+
assert_array_equal(all_ju[i][1], all_sk[i-1][1])
3737

38-
assert all_ju[-1][0].shape[0] == n_samples
39-
assert all_ju[-1][1].shape[0] == 2
40-
assert_array_equal(all_ju[-1][0], np.arange(n_samples))
38+
assert all_ju[0][0].shape[0] == n_samples
39+
assert all_ju[0][1].shape[0] == 2
40+
assert_array_equal(all_ju[0][0], np.arange(n_samples))
4141

4242

4343
def test_final_model_cv_mdsum() -> None:

julearn/models/tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_naive_bayes_estimators(
189189
"estimator": DecisionTreeClassifier(random_state=42),
190190
},
191191
),
192-
("gradientboost", GradientBoostingClassifier, {}),
192+
("gradientboost", GradientBoostingClassifier, {"random_state": 42}),
193193
],
194194
)
195195
def test_classificationestimators(

julearn/tests/test_api.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,8 @@ def test_tune_hyperparam_gridsearch(df_iris: pd.DataFrame) -> None:
415415
scoring = "accuracy"
416416

417417
np.random.seed(42)
418-
cv_outer = RepeatedKFold(n_splits=3, n_repeats=2)
419-
cv_inner = RepeatedKFold(n_splits=3, n_repeats=2)
418+
cv_outer = RepeatedKFold(n_splits=3, n_repeats=2, random_state=9)
419+
cv_inner = RepeatedKFold(n_splits=3, n_repeats=2, random_state=10)
420420

421421
model_params = {"svm__C": [0.01, 0.001]}
422422
search_params = {"cv": cv_inner}
@@ -438,8 +438,8 @@ def test_tune_hyperparam_gridsearch(df_iris: pd.DataFrame) -> None:
438438

439439
# Now do the same with scikit-learn
440440
np.random.seed(42)
441-
cv_outer = RepeatedKFold(n_splits=3, n_repeats=2)
442-
cv_inner = RepeatedKFold(n_splits=3, n_repeats=2)
441+
cv_outer = RepeatedKFold(n_splits=3, n_repeats=2, random_state=9)
442+
cv_inner = RepeatedKFold(n_splits=3, n_repeats=2, random_state=10)
443443

444444
clf = make_pipeline(SVC())
445445
gs = GridSearchCV(
@@ -672,8 +672,8 @@ def test_tune_hyperparams_multiple_grid(df_iris: pd.DataFrame) -> None:
672672
scoring = "accuracy"
673673

674674
np.random.seed(42)
675-
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1)
676-
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1)
675+
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1, random_state=9)
676+
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1, random_state=10)
677677

678678
search_params = {"cv": cv_inner}
679679
actual1, actual_estimator1 = run_cross_validation(
@@ -701,8 +701,8 @@ def test_tune_hyperparams_multiple_grid(df_iris: pd.DataFrame) -> None:
701701
)
702702

703703
np.random.seed(42)
704-
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1)
705-
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1)
704+
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1, random_state=9)
705+
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1, random_state=10)
706706
search_params = {"cv": cv_inner}
707707
actual2, actual_estimator2 = run_cross_validation(
708708
X=X,
@@ -718,8 +718,8 @@ def test_tune_hyperparams_multiple_grid(df_iris: pd.DataFrame) -> None:
718718

719719
# Now do the same with scikit-learn
720720
np.random.seed(42)
721-
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1)
722-
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1)
721+
cv_outer = RepeatedKFold(n_splits=2, n_repeats=1, random_state=9)
722+
cv_inner = RepeatedKFold(n_splits=2, n_repeats=1, random_state=10)
723723

724724
clf = make_pipeline(SVC())
725725
grid = [

0 commit comments

Comments
 (0)