Skip to content

Commit c55d064

Browse files
deepyamanadam2392
andauthored
ENH Adjust estimator representation beyond maxlevels (scikit-learn#29492)
Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 3d5e243 commit c55d064

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

sklearn/utils/_pprint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
427427
if issubclass(typ, BaseEstimator):
428428
objid = id(object)
429429
if maxlevels and level >= maxlevels:
430-
return "{...}", False, objid in context
430+
return f"{typ.__name__}(...)", False, objid in context
431431
if objid in context:
432432
return pprint._recursion(object), False, True
433433
context[objid] = 1

sklearn/utils/tests/test_pprint.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pprint import PrettyPrinter
33

44
import numpy as np
5+
import pytest
56

67
from sklearn.utils._pprint import _EstimatorPrettyPrinter
78
from sklearn.linear_model import LogisticRegressionCV
@@ -346,6 +347,24 @@ def test_deeply_nested(print_changed_only_false):
346347
assert rfe.__repr__() == expected
347348

348349

350+
@pytest.mark.parametrize(
351+
("print_changed_only", "expected"),
352+
[
353+
(True, "RFE(estimator=RFE(...))"),
354+
(
355+
False,
356+
"RFE(estimator=RFE(...), n_features_to_select=None, step=1, verbose=0)",
357+
),
358+
],
359+
)
360+
def test_print_estimator_max_depth(print_changed_only, expected):
361+
with config_context(print_changed_only=print_changed_only):
362+
pp = _EstimatorPrettyPrinter(depth=1)
363+
364+
rfe = RFE(RFE(RFE(RFE(RFE(LogisticRegression())))))
365+
assert pp.pformat(rfe) == expected
366+
367+
349368
def test_gridsearch(print_changed_only_false):
350369
# render a gridsearch
351370
param_grid = [

0 commit comments

Comments
 (0)