Skip to content

Commit 107e009

Browse files
authored
FIX set CategoricalNB().__sklearn_tags__.input_tags.categorical to True (scikit-learn#31556)
1 parent 1d5e692 commit 107e009

File tree

5 files changed

+17
-11
lines changed

5 files changed

+17
-11
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :class:`naive_bayes.CategoricalNB` now correctly declares that it accepts
2+
categorical features in the tags returned by its `__sklearn_tags__` method.
3+
By :user:`Olivier Grisel <ogrisel>`

sklearn/naive_bayes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,6 +1433,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
14331433

14341434
def __sklearn_tags__(self):
14351435
tags = super().__sklearn_tags__()
1436+
tags.input_tags.categorical = True
14361437
tags.input_tags.sparse = False
14371438
tags.input_tags.positive_only = True
14381439
return tags

sklearn/tests/test_naive_bayes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,3 +968,12 @@ def test_predict_joint_proba(Estimator, global_random_seed):
968968
log_prob_x = logsumexp(jll, axis=1)
969969
log_prob_x_y = jll - np.atleast_2d(log_prob_x).T
970970
assert_allclose(est.predict_log_proba(X2), log_prob_x_y)
971+
972+
973+
@pytest.mark.parametrize("Estimator", ALL_NAIVE_BAYES_CLASSES)
974+
def test_categorical_input_tag(Estimator):
975+
tags = Estimator().__sklearn_tags__()
976+
if Estimator is CategoricalNB:
977+
assert tags.input_tags.categorical
978+
else:
979+
assert not tags.input_tags.categorical

sklearn/utils/_test_common/instance_generator.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@
144144
MultiOutputRegressor,
145145
RegressorChain,
146146
)
147-
from sklearn.naive_bayes import CategoricalNB
148147
from sklearn.neighbors import (
149148
KernelDensity,
150149
KNeighborsClassifier,
@@ -898,15 +897,6 @@ def _yield_instances_for_check(check, estimator_orig):
898897
"sample_weight is not equivalent to removing/repeating samples."
899898
),
900899
},
901-
CategoricalNB: {
902-
# TODO: fix sample_weight handling of this estimator, see meta-issue #16298
903-
"check_sample_weight_equivalence_on_dense_data": (
904-
"sample_weight is not equivalent to removing/repeating samples."
905-
),
906-
"check_sample_weight_equivalence_on_sparse_data": (
907-
"sample_weight is not equivalent to removing/repeating samples."
908-
),
909-
},
910900
ColumnTransformer: {
911901
"check_estimators_empty_data_messages": "FIXME",
912902
"check_estimators_nan_inf": "FIXME",

sklearn/utils/estimator_checks.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3997,7 +3997,10 @@ def check_positive_only_tag_during_fit(name, estimator_orig):
39973997
y = _enforce_estimator_tags_y(estimator, y)
39983998
set_random_state(estimator, 0)
39993999
X = _enforce_estimator_tags_X(estimator, X)
4000-
X -= X.mean()
4000+
# Make sure that the dtype of X stays unchanged: for instance, estimators
4001+
# that expect categorical inputs typically expect integer-encoded
4002+
# categories.
4003+
X -= X.mean().astype(X.dtype)
40014004

40024005
if tags.input_tags.positive_only:
40034006
with raises(ValueError, match="Negative values in data"):

0 commit comments

Comments
 (0)