Skip to content

Commit 568c002

Browse files
JungeAlexander authored and amueller committed
[MRG + 1] Move n_iter and get_params invariance tests to common estimator_checks (scikit-learn#7677)
* Test get_params invariance in common estimator tests: Remove test_get_params_invariance() from `test_common.py` and add test call to _yield_all_checks() in `estimator_checks.py` to make sure that get_params(deep=False) of a given Estimator returns a subset of get_params(deep=True). Compared to test_get_params_invariance(), it is NOT tested anymore whether the given Estimator has an attribute get_params since class BaseEstimator in `base.py` defines such an attribute for each Estimator. Partially addresses issue scikit-learn#7533. Also related to issue scikit-learn#4465. * Move test_transformer_n_iter() to estimator_checks.py: Remove the test test_transformer_n_iter() from tests/test_common.py and perform the test logic in utils/estimator_checks.py instead. Specifically, the method _yield_transformer_checks() now yields check_transformer_n_iter() as part of the set of tests for transformers. test_transformer_n_iter() tests that transformers with an attribute max_iter return an attribute n_iter of at least 1. Partially addresses latter part of issue scikit-learn#7533. * Move test_non_transformer_estimators_n_iter() to estimator_checks.py: Remove the test test_non_transformer_estimators_n_iter() from tests/test_common.py; perform the test logic in utils/estimator_checks.py instead. Specifically, the method _yield_non_meta_checks() now yields check_non_transformer_estimators_n_iter(). test_non_transformer_estimators_n_iter() tests that estimators that are not transformers and have an attribute max_iter return an attribute n_iter of at least 1. NOTE: The current implementation makes said test run for more estimators than before this commit. For some of these estimators, the test fails. This needs to be addressed (see FIXME in lines 111-115 of utils/estimator_checks.py for a potential place to start).
Partially addresses latter part of issue scikit-learn#7533. * Fix check_non_transformer_estimators_n_iter calls: test_transformer_n_iter() test is now only run for estimators where the test is applicable. Partially addresses latter part of issue scikit-learn#7533. * Run check_non_transformer_estimators_n_iter on multi-output estimators: To do this, use helper method multioutput_estimator_convert_y_2d. Also remove multi_output parameter from check_non_transformer_estimators_n_iter since this parameter is not used anywhere and corresponding cases should be handled by said helper method. Also, some pep8 line length fixes. * Fix documentation for n_iter tests: There was some confusion between attributes and parameters. Also rename n_iter to n_iter_.
1 parent 61a9a9b commit 568c002

File tree

2 files changed

+69
-108
lines changed

2 files changed

+69
-108
lines changed

sklearn/tests/test_common.py

Lines changed: 1 addition & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,8 @@
2828
from sklearn.linear_model.base import LinearClassifierMixin
2929
from sklearn.utils.estimator_checks import (
3030
_yield_all_checks,
31-
CROSS_DECOMPOSITION,
3231
check_parameters_default_constructible,
33-
check_class_weight_balanced_linear_classifier,
34-
check_transformer_n_iter,
35-
check_non_transformer_estimators_n_iter,
36-
check_get_params_invariance)
32+
check_class_weight_balanced_linear_classifier)
3733

3834

3935
def test_all_estimator_no_base_class():
@@ -162,72 +158,3 @@ def test_all_tests_are_importable():
162158
'{0} do not have `tests` subpackages. Perhaps they require '
163159
'__init__.py or an add_subpackage directive in the parent '
164160
'setup.py'.format(missing_tests))
165-
166-
167-
def test_non_transformer_estimators_n_iter():
168-
# Test that all estimators of type which are non-transformer
169-
# and which have an attribute of max_iter, return the attribute
170-
# of n_iter atleast 1.
171-
for est_type in ['regressor', 'classifier', 'cluster']:
172-
regressors = all_estimators(type_filter=est_type)
173-
for name, Estimator in regressors:
174-
# LassoLars stops early for the default alpha=1.0 for
175-
# the iris dataset.
176-
if name == 'LassoLars':
177-
estimator = Estimator(alpha=0.)
178-
else:
179-
with ignore_warnings(category=DeprecationWarning):
180-
estimator = Estimator()
181-
if hasattr(estimator, "max_iter"):
182-
# These models are dependent on external solvers like
183-
# libsvm and accessing the iter parameter is non-trivial.
184-
if name in (['Ridge', 'SVR', 'NuSVR', 'NuSVC',
185-
'RidgeClassifier', 'SVC', 'RandomizedLasso',
186-
'LogisticRegressionCV']):
187-
continue
188-
189-
# Tested in test_transformer_n_iter below
190-
elif (name in CROSS_DECOMPOSITION or
191-
name in ['LinearSVC', 'LogisticRegression']):
192-
continue
193-
194-
else:
195-
# Multitask models related to ENet cannot handle
196-
# if y is mono-output.
197-
yield (_named_check(
198-
check_non_transformer_estimators_n_iter, name),
199-
name, estimator, 'Multi' in name)
200-
201-
202-
def test_transformer_n_iter():
203-
transformers = all_estimators(type_filter='transformer')
204-
for name, Estimator in transformers:
205-
with ignore_warnings(category=DeprecationWarning):
206-
estimator = Estimator()
207-
# Dependent on external solvers and hence accessing the iter
208-
# param is non-trivial.
209-
external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
210-
'RandomizedLasso', 'LogisticRegressionCV']
211-
212-
if hasattr(estimator, "max_iter") and name not in external_solver:
213-
yield _named_check(
214-
check_transformer_n_iter, name), name, estimator
215-
216-
217-
def test_get_params_invariance():
218-
# Test for estimators that support get_params, that
219-
# get_params(deep=False) is a subset of get_params(deep=True)
220-
# Related to issue #4465
221-
222-
estimators = all_estimators(include_meta_estimators=False,
223-
include_other=True)
224-
for name, Estimator in estimators:
225-
if hasattr(Estimator, 'get_params'):
226-
# If class is deprecated, ignore deprecated warnings
227-
if hasattr(Estimator.__init__, "deprecated_original"):
228-
with ignore_warnings():
229-
yield _named_check(
230-
check_get_params_invariance, name), name, Estimator
231-
else:
232-
yield _named_check(
233-
check_get_params_invariance, name), name, Estimator

sklearn/utils/estimator_checks.py

Lines changed: 68 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,8 @@ def _yield_classifier_checks(name, Classifier):
132132
if 'class_weight' in Classifier().get_params().keys():
133133
yield check_class_weight_classifiers
134134

135+
yield check_non_transformer_estimators_n_iter
136+
135137

136138
@ignore_warnings(category=DeprecationWarning)
137139
def check_supervised_y_no_nan(name, Estimator):
@@ -172,6 +174,7 @@ def _yield_regressor_checks(name, Regressor):
172174
if name != "GaussianProcessRegressor":
173175
# Test if NotFittedError is raised
174176
yield check_estimators_unfitted
177+
yield check_non_transformer_estimators_n_iter
175178

176179

177180
def _yield_transformer_checks(name, Transformer):
@@ -186,6 +189,13 @@ def _yield_transformer_checks(name, Transformer):
186189
# basic tests
187190
yield check_transformer_general
188191
yield check_transformers_unfitted
192+
# Dependent on external solvers and hence accessing the iter
193+
# param is non-trivial.
194+
external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
195+
'RandomizedLasso', 'LogisticRegressionCV']
196+
if name not in external_solver:
197+
yield check_transformer_n_iter
198+
189199

190200

191201
def _yield_clustering_checks(name, Clusterer):
@@ -195,6 +205,7 @@ def _yield_clustering_checks(name, Clusterer):
195205
# let's not test that here.
196206
yield check_clustering
197207
yield check_estimators_partial_fit_n_features
208+
yield check_non_transformer_estimators_n_iter
198209

199210

200211
def _yield_all_checks(name, Estimator):
@@ -218,6 +229,7 @@ def _yield_all_checks(name, Estimator):
218229
yield check_fit2d_1feature
219230
yield check_fit1d_1feature
220231
yield check_fit1d_1sample
232+
yield check_get_params_invariance
221233

222234

223235
def check_estimator(Estimator):
@@ -1477,51 +1489,73 @@ def multioutput_estimator_convert_y_2d(name, y):
14771489

14781490

14791491
@ignore_warnings(category=DeprecationWarning)
1480-
def check_non_transformer_estimators_n_iter(name, estimator,
1481-
multi_output=False):
1482-
# Check if all iterative solvers, run for more than one iteration
1483-
1484-
iris = load_iris()
1485-
X, y_ = iris.data, iris.target
1486-
1487-
if multi_output:
1488-
y_ = np.reshape(y_, (-1, 1))
1492+
def check_non_transformer_estimators_n_iter(name, Estimator):
1493+
# Test that estimators that are not transformers with a parameter
1494+
# max_iter, return the attribute of n_iter_ at least 1.
1495+
1496+
# These models are dependent on external solvers like
1497+
# libsvm and accessing the iter parameter is non-trivial.
1498+
not_run_check_n_iter = ['Ridge', 'SVR', 'NuSVR', 'NuSVC',
1499+
'RidgeClassifier', 'SVC', 'RandomizedLasso',
1500+
'LogisticRegressionCV', 'LinearSVC',
1501+
'LogisticRegression']
1502+
1503+
# Tested in test_transformer_n_iter
1504+
not_run_check_n_iter += CROSS_DECOMPOSITION
1505+
if name in not_run_check_n_iter:
1506+
return
14891507

1490-
set_random_state(estimator, 0)
1491-
if name == 'AffinityPropagation':
1492-
estimator.fit(X)
1508+
# LassoLars stops early for the default alpha=1.0 the iris dataset.
1509+
if name == 'LassoLars':
1510+
estimator = Estimator(alpha=0.)
14931511
else:
1494-
estimator.fit(X, y_)
1512+
estimator = Estimator()
1513+
if hasattr(estimator, 'max_iter'):
1514+
iris = load_iris()
1515+
X, y_ = iris.data, iris.target
1516+
y_ = multioutput_estimator_convert_y_2d(name, y_)
1517+
1518+
set_random_state(estimator, 0)
1519+
if name == 'AffinityPropagation':
1520+
estimator.fit(X)
1521+
else:
1522+
estimator.fit(X, y_)
14951523

1496-
# HuberRegressor depends on scipy.optimize.fmin_l_bfgs_b
1497-
# which doesn't return a n_iter for old versions of SciPy.
1498-
if not (name == 'HuberRegressor' and estimator.n_iter_ is None):
1499-
assert_greater_equal(estimator.n_iter_, 1)
1524+
# HuberRegressor depends on scipy.optimize.fmin_l_bfgs_b
1525+
# which doesn't return a n_iter for old versions of SciPy.
1526+
if not (name == 'HuberRegressor' and estimator.n_iter_ is None):
1527+
assert_greater_equal(estimator.n_iter_, 1)
15001528

15011529

15021530
@ignore_warnings(category=DeprecationWarning)
1503-
def check_transformer_n_iter(name, estimator):
1504-
if name in CROSS_DECOMPOSITION:
1505-
# Check using default data
1506-
X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
1507-
y_ = [[0.1, -0.2], [0.9, 1.1], [0.1, -0.5], [0.3, -0.2]]
1531+
def check_transformer_n_iter(name, Estimator):
1532+
# Test that transformers with a parameter max_iter, return the
1533+
# attribute of n_iter_ at least 1.
1534+
estimator = Estimator()
1535+
if hasattr(estimator, "max_iter"):
1536+
if name in CROSS_DECOMPOSITION:
1537+
# Check using default data
1538+
X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
1539+
y_ = [[0.1, -0.2], [0.9, 1.1], [0.1, -0.5], [0.3, -0.2]]
15081540

1509-
else:
1510-
X, y_ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
1511-
random_state=0, n_features=2, cluster_std=0.1)
1512-
X -= X.min() - 0.1
1513-
set_random_state(estimator, 0)
1514-
estimator.fit(X, y_)
1541+
else:
1542+
X, y_ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
1543+
random_state=0, n_features=2, cluster_std=0.1)
1544+
X -= X.min() - 0.1
1545+
set_random_state(estimator, 0)
1546+
estimator.fit(X, y_)
15151547

1516-
# These return a n_iter per component.
1517-
if name in CROSS_DECOMPOSITION:
1518-
for iter_ in estimator.n_iter_:
1519-
assert_greater_equal(iter_, 1)
1520-
else:
1521-
assert_greater_equal(estimator.n_iter_, 1)
1548+
# These return a n_iter per component.
1549+
if name in CROSS_DECOMPOSITION:
1550+
for iter_ in estimator.n_iter_:
1551+
assert_greater_equal(iter_, 1)
1552+
else:
1553+
assert_greater_equal(estimator.n_iter_, 1)
15221554

15231555

1556+
@ignore_warnings(category=DeprecationWarning)
15241557
def check_get_params_invariance(name, estimator):
1558+
# Checks if get_params(deep=False) is a subset of get_params(deep=True)
15251559
class T(BaseEstimator):
15261560
"""Mock classifier
15271561
"""

0 commit comments

Comments
 (0)