Skip to content

Commit 5b6622b

Browse files
FIX improve doc and test for HTMLDocumentationLinkMixin (scikit-learn#29774)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent b3c213b commit 5b6622b

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

sklearn/utils/_estimator_html_repr.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,15 +466,31 @@ class _HTMLDocumentationLinkMixin:
466466
Examples
467467
--------
468468
If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
469-
then you can override them:
469+
then you can override them and provide a method to generate the URL parameters:
470470
>>> from sklearn.base import BaseEstimator
471-
>>> estimator = BaseEstimator()
472-
>>> estimator._doc_link_template = "https://website.com/{single_param}.html"
471+
>>> doc_link_template = "https://address.local/{single_param}.html"
473472
>>> def url_param_generator(estimator):
474473
... return {"single_param": estimator.__class__.__name__}
475-
>>> estimator._doc_link_url_param_generator = url_param_generator
474+
>>> class MyEstimator(BaseEstimator):
475+
... # use "builtins" since it is the associated module when declaring
476+
... # the class in a docstring
477+
... _doc_link_module = "builtins"
478+
... _doc_link_template = doc_link_template
479+
... _doc_link_url_param_generator = url_param_generator
480+
>>> estimator = MyEstimator()
476481
>>> estimator._get_doc_link()
477-
'https://website.com/BaseEstimator.html'
482+
'https://address.local/MyEstimator.html'
483+
484+
If instead of overriding the attributes inside the class definition, you want to
485+
override a class instance, you can use `types.MethodType` to bind the method to the
486+
instance:
487+
>>> import types
488+
>>> estimator = BaseEstimator()
489+
>>> estimator._doc_link_template = doc_link_template
490+
>>> estimator._doc_link_url_param_generator = types.MethodType(
491+
... url_param_generator, estimator)
492+
>>> estimator._get_doc_link()
493+
'https://address.local/BaseEstimator.html'
478494
"""
479495

480496
_doc_link_module = "sklearn"
@@ -530,6 +546,4 @@ def _get_doc_link(self):
530546
return self._doc_link_template.format(
531547
estimator_module=estimator_module, estimator_name=estimator_name
532548
)
533-
return self._doc_link_template.format(
534-
**self._doc_link_url_param_generator(self)
535-
)
549+
return self._doc_link_template.format(**self._doc_link_url_param_generator())

sklearn/utils/tests/test_estimator_html_repr.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import html
22
import locale
33
import re
4+
import types
45
from contextlib import closing
56
from functools import partial
67
from io import StringIO
@@ -453,7 +454,9 @@ def test_html_documentation_link_mixin_sklearn(mock_version):
453454
("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
454455
],
455456
)
456-
def test_html_documentation_link_mixin_get_doc_link(module_path, expected_module):
457+
def test_html_documentation_link_mixin_get_doc_link_instance(
458+
module_path, expected_module
459+
):
457460
"""Check the behaviour of the `_get_doc_link` with various parameter."""
458461

459462
class FooBar(_HTMLDocumentationLinkMixin):
@@ -469,6 +472,32 @@ class FooBar(_HTMLDocumentationLinkMixin):
469472
assert est._get_doc_link() == f"https://website.com/{expected_module}.FooBar.html"
470473

471474

475+
@pytest.mark.parametrize(
476+
"module_path,expected_module",
477+
[
478+
("prefix.mymodule", "prefix.mymodule"),
479+
("prefix._mymodule", "prefix"),
480+
("prefix.mypackage._mymodule", "prefix.mypackage"),
481+
("prefix.mypackage._mymodule.submodule", "prefix.mypackage"),
482+
("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
483+
],
484+
)
485+
def test_html_documentation_link_mixin_get_doc_link_class(module_path, expected_module):
486+
"""Check the behaviour of the `_get_doc_link` when `_doc_link_module` and
487+
`_doc_link_template` are defined at the class level and not at the instance
488+
level."""
489+
490+
class FooBar(_HTMLDocumentationLinkMixin):
491+
_doc_link_module = "prefix"
492+
_doc_link_template = (
493+
"https://website.com/{estimator_module}.{estimator_name}.html"
494+
)
495+
496+
FooBar.__module__ = module_path
497+
est = FooBar()
498+
assert est._get_doc_link() == f"https://website.com/{expected_module}.FooBar.html"
499+
500+
472501
def test_html_documentation_link_mixin_get_doc_link_out_of_library():
473502
"""Check the behaviour of the `_get_doc_link` with various parameter."""
474503
mixin = _HTMLDocumentationLinkMixin()
@@ -479,7 +508,7 @@ def test_html_documentation_link_mixin_get_doc_link_out_of_library():
479508
assert mixin._get_doc_link() == ""
480509

481510

482-
def test_html_documentation_link_mixin_doc_link_url_param_generator():
511+
def test_html_documentation_link_mixin_doc_link_url_param_generator_instance():
483512
mixin = _HTMLDocumentationLinkMixin()
484513
# we can bypass the generation by providing our own callable
485514
mixin._doc_link_template = (
@@ -492,11 +521,30 @@ def url_param_generator(estimator):
492521
"another_variable": "value_2",
493522
}
494523

495-
mixin._doc_link_url_param_generator = url_param_generator
524+
mixin._doc_link_url_param_generator = types.MethodType(url_param_generator, mixin)
496525

497526
assert mixin._get_doc_link() == "https://website.com/value_1.value_2.html"
498527

499528

529+
def test_html_documentation_link_mixin_doc_link_url_param_generator_class():
530+
# we can bypass the generation by providing our own callable
531+
532+
def url_param_generator(estimator):
533+
return {
534+
"my_own_variable": "value_1",
535+
"another_variable": "value_2",
536+
}
537+
538+
class FooBar(_HTMLDocumentationLinkMixin):
539+
_doc_link_template = (
540+
"https://website.com/{my_own_variable}.{another_variable}.html"
541+
)
542+
_doc_link_url_param_generator = url_param_generator
543+
544+
estimator = FooBar()
545+
assert estimator._get_doc_link() == "https://website.com/value_1.value_2.html"
546+
547+
500548
@pytest.fixture
501549
def set_non_utf8_locale():
502550
"""Pytest fixture to set non utf-8 locale during the test.

0 commit comments

Comments
 (0)