Skip to content

Commit e490da9

Browse files
MAINT Parameters validation for sklearn.utils.safe_mask (scikit-learn#26131)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 76d9be2 commit e490da9

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def _check_function_param_validation(
285285
"sklearn.utils.gen_batches",
286286
"sklearn.utils.graph.single_source_shortest_path_length",
287287
"sklearn.utils.resample",
288+
"sklearn.utils.safe_mask",
288289
]
289290

290291

sklearn/utils/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ def _in_unstable_openblas_configuration():
112112
return False
113113

114114

115+
@validate_params(
116+
{
117+
"X": ["array-like", "sparse matrix"],
118+
"mask": ["array-like"],
119+
},
120+
prefer_skip_nested_validation=True,
121+
)
115122
def safe_mask(X, mask):
116123
"""Return a mask which is safe to use on X.
117124
@@ -120,7 +127,7 @@ def safe_mask(X, mask):
120127
X : {array-like, sparse matrix}
121128
Data on which to apply mask.
122129
123-
mask : ndarray
130+
mask : array-like
124131
Mask to be used on X.
125132
126133
Returns

0 commit comments

Comments
 (0)