Skip to content

Commit f916449

Browse files
authored
Merge pull request scikit-learn#7317 from amueller/common_test_names
[MRG+1] make more explicit which checks are run
2 parents 49fb295 + 7fc4176 commit f916449

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

sklearn/tests/test_common.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
check_class_weight_balanced_linear_classifier,
3434
check_transformer_n_iter,
3535
check_non_transformer_estimators_n_iter,
36-
check_get_params_invariance)
36+
check_get_params_invariance,
37+
_set_test_name)
3738

3839

3940
def test_all_estimator_no_base_class():
@@ -55,7 +56,8 @@ def test_all_estimators():
5556

5657
for name, Estimator in estimators:
5758
# some can just not be sensibly default constructed
58-
yield check_parameters_default_constructible, name, Estimator
59+
yield (_set_test_name(check_parameters_default_constructible, name),
60+
name, Estimator)
5961

6062

6163
def test_non_meta_estimators():
@@ -70,9 +72,9 @@ def test_non_meta_estimators():
7072
if issubclass(Estimator, ProjectedGradientNMF):
7173
# The ProjectedGradientNMF class is deprecated
7274
with ignore_warnings():
73-
yield check, name, Estimator
75+
yield _set_test_name(check, name), name, Estimator
7476
else:
75-
yield check, name, Estimator
77+
yield _set_test_name(check, name), name, Estimator
7678

7779

7880
def test_configure():
@@ -114,7 +116,8 @@ def test_class_weight_balanced_linear_classifiers():
114116
issubclass(clazz, LinearClassifierMixin))]
115117

116118
for name, Classifier in linear_classifiers:
117-
yield check_class_weight_balanced_linear_classifier, name, Classifier
119+
yield _set_test_name(check_class_weight_balanced_linear_classifier,
120+
name), name, Classifier
118121

119122

120123
@ignore_warnings
@@ -196,8 +199,9 @@ def test_non_transformer_estimators_n_iter():
196199
else:
197200
# Multitask models related to ENet cannot handle
198201
# if y is mono-output.
199-
yield (check_non_transformer_estimators_n_iter,
200-
name, estimator, 'Multi' in name)
202+
yield (_set_test_name(
203+
check_non_transformer_estimators_n_iter, name),
204+
name, estimator, 'Multi' in name)
201205

202206

203207
def test_transformer_n_iter():
@@ -218,9 +222,12 @@ def test_transformer_n_iter():
218222
if isinstance(estimator, ProjectedGradientNMF):
219223
# The ProjectedGradientNMF class is deprecated
220224
with ignore_warnings():
221-
yield check_transformer_n_iter, name, estimator
225+
yield _set_test_name(
226+
check_transformer_n_iter, name), name, estimator
222227
else:
223-
yield check_transformer_n_iter, name, estimator
228+
yield _set_test_name(
229+
check_transformer_n_iter, name), name, estimator
230+
224231

225232
def test_get_params_invariance():
226233
# Test for estimators that support get_params, that
@@ -234,6 +241,8 @@ def test_get_params_invariance():
234241
# If class is deprecated, ignore deprecated warnings
235242
if hasattr(Estimator.__init__, "deprecated_original"):
236243
with ignore_warnings():
237-
yield check_get_params_invariance, name, Estimator
244+
yield _set_test_name(
245+
check_get_params_invariance, name), name, Estimator
238246
else:
239-
yield check_get_params_invariance, name, Estimator
247+
yield _set_test_name(
248+
check_get_params_invariance, name), name, Estimator

sklearn/utils/estimator_checks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@
7575
"GradientBoostingClassifier", "GradientBoostingRegressor"]
7676

7777

78+
def _set_test_name(function, name):
79+
function.description = ("sklearn.tests.test_common.{0}({1})".format(
80+
function.__name__, name))
81+
return function
82+
83+
7884
def _yield_non_meta_checks(name, Estimator):
7985
yield check_estimators_dtypes
8086
yield check_fit_score_takes_y

0 commit comments

Comments
 (0)