Skip to content

Commit d077f82

Browse files
ENH: Display parameters in HTML representation (scikit-learn#30763)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 19a6e61 commit d077f82

File tree

21 files changed

+595
-181
lines changed

21 files changed

+595
-181
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :class:`base.BaseEstimator` now has a parameter table added to the
2+
estimators HTML representation that can be visualized with jupyter.
3+
By :user:`Guillaume Lemaitre <glemaitre>` and
4+
:user:`Dea María Léon <DeaMariaLeon>`

sklearn/base.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from . import __version__
1717
from ._config import config_context, get_config
1818
from .exceptions import InconsistentVersionWarning
19-
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
2019
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
20+
from .utils._missing import is_scalar_nan
2121
from .utils._param_validation import validate_parameter_constraints
22+
from .utils._repr_html.base import ReprHTMLMixin, _HTMLDocumentationLinkMixin
23+
from .utils._repr_html.estimator import estimator_html_repr
24+
from .utils._repr_html.params import ParamsDict
2225
from .utils._set_output import _SetOutputMixin
2326
from .utils._tags import (
2427
ClassifierTags,
@@ -150,7 +153,7 @@ def _clone_parametrized(estimator, *, safe=True):
150153
return new_object
151154

152155

153-
class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
156+
class BaseEstimator(ReprHTMLMixin, _HTMLDocumentationLinkMixin, _MetadataRequester):
154157
"""Base class for all estimators in scikit-learn.
155158
156159
Inheriting from this class provides default implementations of:
@@ -194,6 +197,8 @@ class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
194197
array([3, 3, 3])
195198
"""
196199

200+
_html_repr = estimator_html_repr
201+
197202
@classmethod
198203
def _get_param_names(cls):
199204
"""Get parameter names for the estimator"""
@@ -249,6 +254,64 @@ def get_params(self, deep=True):
249254
out[key] = value
250255
return out
251256

257+
def _get_params_html(self, deep=True):
258+
"""
259+
Get parameters for this estimator with a specific HTML representation.
260+
261+
Parameters
262+
----------
263+
deep : bool, default=True
264+
If True, will return the parameters for this estimator and
265+
contained subobjects that are estimators.
266+
267+
Returns
268+
-------
269+
params : ParamsDict
270+
Parameter names mapped to their values. We return a `ParamsDict`
271+
dictionary, which renders a specific HTML representation in table
272+
form.
273+
"""
274+
out = self.get_params(deep=deep)
275+
276+
init_func = getattr(self.__init__, "deprecated_original", self.__init__)
277+
init_default_params = inspect.signature(init_func).parameters
278+
init_default_params = {
279+
name: param.default for name, param in init_default_params.items()
280+
}
281+
282+
def is_non_default(param_name, param_value):
283+
"""Finds the parameters that have been set by the user."""
284+
if param_name not in init_default_params:
285+
# happens if k is part of a **kwargs
286+
return True
287+
if init_default_params[param_name] == inspect._empty:
288+
# k has no default value
289+
return True
290+
# avoid calling repr on nested estimators
291+
if isinstance(param_value, BaseEstimator) and type(param_value) is not type(
292+
init_default_params[param_name]
293+
):
294+
return True
295+
296+
if param_value != init_default_params[param_name] and not (
297+
is_scalar_nan(init_default_params[param_name])
298+
and is_scalar_nan(param_value)
299+
):
300+
return True
301+
return False
302+
303+
# reorder the parameters from `self.get_params` using the `__init__`
304+
# signature
305+
remaining_params = [name for name in out if name not in init_default_params]
306+
ordered_out = {name: out[name] for name in init_default_params if name in out}
307+
ordered_out.update({name: out[name] for name in remaining_params})
308+
309+
non_default_ls = tuple(
310+
[name for name, value in ordered_out.items() if is_non_default(name, value)]
311+
)
312+
313+
return ParamsDict(ordered_out, non_default=non_default_ls)
314+
252315
def set_params(self, **params):
253316
"""Set the parameters of this estimator.
254317
@@ -409,36 +472,6 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
409472
caller_name=self.__class__.__name__,
410473
)
411474

412-
@property
413-
def _repr_html_(self):
414-
"""HTML representation of estimator.
415-
416-
This is redundant with the logic of `_repr_mimebundle_`. The latter
417-
should be favored in the long term, `_repr_html_` is only
418-
implemented for consumers who do not interpret `_repr_mimbundle_`.
419-
"""
420-
if get_config()["display"] != "diagram":
421-
raise AttributeError(
422-
"_repr_html_ is only defined when the "
423-
"'display' configuration option is set to "
424-
"'diagram'"
425-
)
426-
return self._repr_html_inner
427-
428-
def _repr_html_inner(self):
429-
"""This function is returned by the @property `_repr_html_` to make
430-
`hasattr(estimator, "_repr_html_") return `True` or `False` depending
431-
on `get_config()["display"]`.
432-
"""
433-
return estimator_html_repr(self)
434-
435-
def _repr_mimebundle_(self, **kwargs):
436-
"""Mime bundle used by jupyter kernels to display estimator"""
437-
output = {"text/plain": repr(self)}
438-
if get_config()["display"] == "diagram":
439-
output["text/html"] = estimator_html_repr(self)
440-
return output
441-
442475

443476
class ClassifierMixin:
444477
"""Mixin class for all classifiers in scikit-learn.

sklearn/compose/_column_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
2121
from ..preprocessing import FunctionTransformer
2222
from ..utils import Bunch
23-
from ..utils._estimator_html_repr import _VisualBlock
2423
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_indexing
2524
from ..utils._metadata_requests import METHODS
2625
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
26+
from ..utils._repr_html.estimator import _VisualBlock
2727
from ..utils._set_output import (
2828
_get_container_adapter,
2929
_get_output_config,

sklearn/ensemble/_stacking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from ..model_selection import check_cv, cross_val_predict
2525
from ..preprocessing import LabelEncoder
2626
from ..utils import Bunch
27-
from ..utils._estimator_html_repr import _VisualBlock
2827
from ..utils._param_validation import HasMethods, StrOptions
28+
from ..utils._repr_html.estimator import _VisualBlock
2929
from ..utils.metadata_routing import (
3030
MetadataRouter,
3131
MethodMapping,

sklearn/ensemble/_voting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from ..exceptions import NotFittedError
2525
from ..preprocessing import LabelEncoder
2626
from ..utils import Bunch
27-
from ..utils._estimator_html_repr import _VisualBlock
2827
from ..utils._param_validation import StrOptions
28+
from ..utils._repr_html.estimator import _VisualBlock
2929
from ..utils.metadata_routing import (
3030
MetadataRouter,
3131
MethodMapping,

sklearn/model_selection/_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
get_scorer_names,
3232
)
3333
from ..utils import Bunch, check_random_state
34-
from ..utils._estimator_html_repr import _VisualBlock
3534
from ..utils._param_validation import HasMethods, Interval, StrOptions
35+
from ..utils._repr_html.estimator import _VisualBlock
3636
from ..utils._tags import get_tags
3737
from ..utils.metadata_routing import (
3838
MetadataRouter,

sklearn/model_selection/tests/test_search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2662,21 +2662,21 @@ def test_search_html_repr():
26622662
search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=False)
26632663
with config_context(display="diagram"):
26642664
repr_html = search_cv._repr_html_()
2665-
assert "<pre>DummyClassifier()</pre>" in repr_html
2665+
assert "<div>DummyClassifier</div>" in repr_html
26662666

26672667
# Fitted with `refit=False` shows the original pipeline
26682668
search_cv.fit(X, y)
26692669
with config_context(display="diagram"):
26702670
repr_html = search_cv._repr_html_()
2671-
assert "<pre>DummyClassifier()</pre>" in repr_html
2671+
assert "<div>DummyClassifier</div>" in repr_html
26722672

26732673
# Fitted with `refit=True` shows the best estimator
26742674
search_cv = GridSearchCV(pipeline, param_grid=param_grid, refit=True)
26752675
search_cv.fit(X, y)
26762676
with config_context(display="diagram"):
26772677
repr_html = search_cv._repr_html_()
2678-
assert "<pre>DummyClassifier()</pre>" not in repr_html
2679-
assert "<pre>LogisticRegression()</pre>" in repr_html
2678+
assert "<div>DummyClassifier</div>" not in repr_html
2679+
assert "<div>LogisticRegression</div>" in repr_html
26802680

26812681

26822682
# Metadata Routing Tests

sklearn/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .exceptions import NotFittedError
1717
from .preprocessing import FunctionTransformer
1818
from .utils import Bunch
19-
from .utils._estimator_html_repr import _VisualBlock
2019
from .utils._metadata_requests import METHODS
2120
from .utils._param_validation import HasMethods, Hidden
21+
from .utils._repr_html.estimator import _VisualBlock
2222
from .utils._set_output import (
2323
_get_container_adapter,
2424
_safe_set_output,

sklearn/preprocessing/_function_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import numpy as np
88

99
from ..base import BaseEstimator, TransformerMixin, _fit_context
10-
from ..utils._estimator_html_repr import _VisualBlock
1110
from ..utils._param_validation import StrOptions
11+
from ..utils._repr_html.estimator import _VisualBlock
1212
from ..utils._set_output import (
1313
_get_adapter_from_container,
1414
_get_output_config,

sklearn/tests/test_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,3 +992,11 @@ def predict(self, X, prop=None):
992992
with warnings.catch_warnings(record=True) as record:
993993
CustomOutlierDetector().set_predict_request(prop=True).fit_predict([[1]], [1])
994994
assert len(record) == 0
995+
996+
997+
def test_get_params_html():
998+
"""Check the behaviour of the `_get_params_html` method."""
999+
est = MyEstimator(empty="test")
1000+
1001+
assert est._get_params_html() == {"l1": 0, "empty": "test"}
1002+
assert est._get_params_html().non_default == ("empty",)

0 commit comments

Comments
 (0)