 from sklearn.metrics import get_scorer, log_loss
 from sklearn.model_selection import (
     GridSearchCV,
+    LeaveOneGroupOut,
     StratifiedKFold,
     cross_val_score,
     train_test_split,
@@ -775,86 +776,167 @@ def test_logistic_regressioncv_class_weights(weight, class_weight, global_random
 )
 
 
-def test_logistic_regression_sample_weights():
+@pytest.mark.parametrize("problem", ("single", "cv"))
+@pytest.mark.parametrize(
+    "solver", ("lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga")
+)
+def test_logistic_regression_sample_weights(problem, solver, global_random_seed):
+    n_samples_per_cv_group = 200
+    n_cv_groups = 3
+
     X, y = make_classification(
-        n_samples=20, n_features=5, n_informative=3, n_classes=2, random_state=0
+        n_samples=n_samples_per_cv_group * n_cv_groups,
+        n_features=5,
+        n_informative=3,
+        n_classes=2,
+        n_redundant=0,
+        random_state=global_random_seed,
     )
+    rng = np.random.RandomState(global_random_seed)
+    sw = np.ones(y.shape[0])
+
+    kw_weighted = {
+        "random_state": global_random_seed,
+        "fit_intercept": False,
+        "max_iter": 100_000 if solver.startswith("sag") else 1_000,
+        "tol": 1e-8,
+    }
+    kw_repeated = kw_weighted.copy()
+    sw[:n_samples_per_cv_group] = rng.randint(0, 5, size=n_samples_per_cv_group)
+    X_repeated = np.repeat(X, sw.astype(int), axis=0)
+    y_repeated = np.repeat(y, sw.astype(int), axis=0)
+
+    if problem == "single":
+        LR = LogisticRegression
+    elif problem == "cv":
+        LR = LogisticRegressionCV
+        # We weight the first fold 2 times more.
+        groups_weighted = np.concatenate(
+            [
+                np.full(n_samples_per_cv_group, 0),
+                np.full(n_samples_per_cv_group, 1),
+                np.full(n_samples_per_cv_group, 2),
+            ]
+        )
+        splits_weighted = list(LeaveOneGroupOut().split(X, groups=groups_weighted))
+        kw_weighted.update({"Cs": 100, "cv": splits_weighted})
+
+        groups_repeated = np.repeat(groups_weighted, sw.astype(int), axis=0)
+        splits_repeated = list(
+            LeaveOneGroupOut().split(X_repeated, groups=groups_repeated)
+        )
+        kw_repeated.update({"Cs": 100, "cv": splits_repeated})
+
+    clf_sw_weighted = LR(solver=solver, **kw_weighted)
+    clf_sw_repeated = LR(solver=solver, **kw_repeated)
+
+    if solver == "lbfgs":
+        # lbfgs has convergence issues on the data but this should not impact
+        # the quality of the results.
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", ConvergenceWarning)
+            clf_sw_weighted.fit(X, y, sample_weight=sw)
+            clf_sw_repeated.fit(X_repeated, y_repeated)
+
+    else:
+        clf_sw_weighted.fit(X, y, sample_weight=sw)
+        clf_sw_repeated.fit(X_repeated, y_repeated)
+
+    if problem == "cv":
+        assert_allclose(clf_sw_weighted.scores_[1], clf_sw_repeated.scores_[1])
+    assert_allclose(clf_sw_weighted.coef_, clf_sw_repeated.coef_, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "solver", ("lbfgs", "newton-cg", "newton-cholesky", "sag", "saga")
+)
+def test_logistic_regression_solver_class_weights(solver, global_random_seed):
+    # Test that passing class_weight as [1, 2] is the same as
+    # passing class weight = [1,1] but adjusting sample weights
+    # to be 2 for all instances of class 1.
+
+    X, y = make_classification(
+        n_samples=300,
+        n_features=5,
+        n_informative=3,
+        n_classes=2,
+        random_state=global_random_seed,
+    )
+
     sample_weight = y + 1
 
-    for LR in [LogisticRegression, LogisticRegressionCV]:
-        kw = {"random_state": 42, "fit_intercept": False}
-        if LR is LogisticRegressionCV:
-            kw.update({"Cs": 3, "cv": 3})
-
-        # Test that passing sample_weight as ones is the same as
-        # not passing them at all (default None)
-        for solver in ["lbfgs", "liblinear"]:
-            clf_sw_none = LR(solver=solver, **kw)
-            clf_sw_ones = LR(solver=solver, **kw)
-            clf_sw_none.fit(X, y)
-            clf_sw_ones.fit(X, y, sample_weight=np.ones(y.shape[0]))
-            assert_allclose(clf_sw_none.coef_, clf_sw_ones.coef_, rtol=1e-4)
-
-        # Test that sample weights work the same with the lbfgs,
-        # newton-cg, newton-cholesky and 'sag' solvers
-        clf_sw_lbfgs = LR(**kw, tol=1e-5)
-        clf_sw_lbfgs.fit(X, y, sample_weight=sample_weight)
-        for solver in set(SOLVERS) - set(["lbfgs"]):
-            clf_sw = LR(solver=solver, tol=1e-10 if solver == "sag" else 1e-5, **kw)
-            # ignore convergence warning due to small dataset with sag
-            with ignore_warnings():
-                clf_sw.fit(X, y, sample_weight=sample_weight)
-            assert_allclose(clf_sw_lbfgs.coef_, clf_sw.coef_, rtol=1e-4)
-
-        # Test that passing class_weight as [1,2] is the same as
-        # passing class weight = [1,1] but adjusting sample weights
-        # to be 2 for all instances of class 2
-        for solver in ["lbfgs", "liblinear"]:
-            clf_cw_12 = LR(solver=solver, class_weight={0: 1, 1: 2}, **kw)
-            clf_cw_12.fit(X, y)
-            clf_sw_12 = LR(solver=solver, **kw)
-            clf_sw_12.fit(X, y, sample_weight=sample_weight)
-            assert_allclose(clf_cw_12.coef_, clf_sw_12.coef_, rtol=1e-4)
+    kw_weighted = {
+        "random_state": global_random_seed,
+        "fit_intercept": False,
+        "max_iter": 100_000,
+        "tol": 1e-8,
+    }
+    clf_cw_12 = LogisticRegression(
+        solver=solver, class_weight={0: 1, 1: 2}, **kw_weighted
+    )
+    clf_cw_12.fit(X, y)
+    clf_sw_12 = LogisticRegression(solver=solver, **kw_weighted)
+    clf_sw_12.fit(X, y, sample_weight=sample_weight)
+    assert_allclose(clf_cw_12.coef_, clf_sw_12.coef_, atol=1e-6)
+
 
+def test_sample_and_class_weight_equivalence_liblinear(global_random_seed):
     # Test the above for l1 penalty and l2 penalty with dual=True.
     # since the patched liblinear code is different.
+
+    X, y = make_classification(
+        n_samples=300,
+        n_features=5,
+        n_informative=3,
+        n_classes=2,
+        random_state=global_random_seed,
+    )
+
+    sample_weight = y + 1
+
     clf_cw = LogisticRegression(
         solver="liblinear",
         fit_intercept=False,
         class_weight={0: 1, 1: 2},
         penalty="l1",
-        tol=1e-5,
-        random_state=42,
+        max_iter=10_000,
+        tol=1e-12,
+        random_state=global_random_seed,
     )
     clf_cw.fit(X, y)
     clf_sw = LogisticRegression(
         solver="liblinear",
         fit_intercept=False,
         penalty="l1",
-        tol=1e-5,
-        random_state=42,
+        max_iter=10_000,
+        tol=1e-12,
+        random_state=global_random_seed,
     )
     clf_sw.fit(X, y, sample_weight)
-    assert_array_almost_equal(clf_cw.coef_, clf_sw.coef_, decimal=4)
+    assert_allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-10)
 
     clf_cw = LogisticRegression(
         solver="liblinear",
         fit_intercept=False,
         class_weight={0: 1, 1: 2},
         penalty="l2",
+        max_iter=10_000,
+        tol=1e-12,
         dual=True,
-        random_state=42,
+        random_state=global_random_seed,
     )
     clf_cw.fit(X, y)
     clf_sw = LogisticRegression(
         solver="liblinear",
         fit_intercept=False,
         penalty="l2",
+        max_iter=10_000,
+        tol=1e-12,
         dual=True,
-        random_state=42,
+        random_state=global_random_seed,
     )
     clf_sw.fit(X, y, sample_weight)
-    assert_array_almost_equal(clf_cw.coef_, clf_sw.coef_, decimal=4)
+    assert_allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-10)
 
 
 def _compute_class_weight_dictionary(y):