Skip to content

Commit 04575f5

Browse files
authored
MAINT Param validation: apply skip nested validation to all functions (scikit-learn#26495)
1 parent 5e8d8cb commit 04575f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+411
-210
lines changed

sklearn/calibration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,8 @@ def predict(self, T):
955955
"pos_label": [Real, str, "boolean", None],
956956
"n_bins": [Interval(Integral, 1, None, closed="left")],
957957
"strategy": [StrOptions({"uniform", "quantile"})],
958-
}
958+
},
959+
prefer_skip_nested_validation=True,
959960
)
960961
def calibration_curve(
961962
y_true,

sklearn/cluster/_affinity_propagation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ def _affinity_propagation(
185185
{
186186
"S": ["array-like"],
187187
"return_n_iter": ["boolean"],
188-
}
188+
},
189+
prefer_skip_nested_validation=False,
189190
)
190191
def affinity_propagation(
191192
S,

sklearn/cluster/_agglomerative.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ def _single_linkage_tree(
185185
"connectivity": ["array-like", "sparse matrix", None],
186186
"n_clusters": [Interval(Integral, 1, None, closed="left"), None],
187187
"return_distance": ["boolean"],
188-
}
188+
},
189+
prefer_skip_nested_validation=True,
189190
)
190191
def ward_tree(X, *, connectivity=None, n_clusters=None, return_distance=False):
191192
"""Ward clustering based on a Feature matrix.

sklearn/cluster/_kmeans.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
"x_squared_norms": ["array-like", None],
6767
"random_state": ["random_state"],
6868
"n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
69-
}
69+
},
70+
prefer_skip_nested_validation=True,
7071
)
7172
def kmeans_plusplus(
7273
X,
@@ -293,24 +294,10 @@ def _tolerance(X, tol):
293294
@validate_params(
294295
{
295296
"X": ["array-like", "sparse matrix"],
296-
"n_clusters": [Interval(Integral, 1, None, closed="left")],
297297
"sample_weight": ["array-like", None],
298-
"init": [StrOptions({"k-means++", "random"}), callable, "array-like"],
299-
"n_init": [
300-
StrOptions({"auto"}),
301-
Hidden(StrOptions({"warn"})),
302-
Interval(Integral, 1, None, closed="left"),
303-
],
304-
"max_iter": [Interval(Integral, 1, None, closed="left")],
305-
"verbose": [Interval(Integral, 0, None, closed="left"), bool],
306-
"tol": [Interval(Real, 0, None, closed="left")],
307-
"random_state": ["random_state"],
308-
"copy_x": [bool],
309-
"algorithm": [
310-
StrOptions({"lloyd", "elkan", "auto", "full"}, deprecated={"auto", "full"})
311-
],
312298
"return_n_iter": [bool],
313-
}
299+
},
300+
prefer_skip_nested_validation=False,
314301
)
315302
def k_means(
316303
X,

sklearn/cluster/_mean_shift.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
"n_samples": [Interval(Integral, 1, None, closed="left"), None],
3838
"random_state": ["random_state"],
3939
"n_jobs": [Integral, None],
40-
}
40+
},
41+
prefer_skip_nested_validation=True,
4142
)
4243
def estimate_bandwidth(X, *, quantile=0.3, n_samples=None, random_state=0, n_jobs=None):
4344
"""Estimate the bandwidth to use with the mean-shift algorithm.
@@ -120,7 +121,10 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
120121
return tuple(my_mean), len(points_within), completed_iterations
121122

122123

123-
@validate_params({"X": ["array-like"]})
124+
@validate_params(
125+
{"X": ["array-like"]},
126+
prefer_skip_nested_validation=False,
127+
)
124128
def mean_shift(
125129
X,
126130
*,

sklearn/cluster/_optics.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,8 @@ def _compute_core_distances_(X, neighbors, min_samples, working_memory):
447447
"algorithm": [StrOptions({"auto", "brute", "ball_tree", "kd_tree"})],
448448
"leaf_size": [Interval(Integral, 1, None, closed="left")],
449449
"n_jobs": [Integral, None],
450-
}
450+
},
451+
prefer_skip_nested_validation=False, # metric is not validated yet
451452
)
452453
def compute_optics_graph(
453454
X, *, min_samples, max_eps, metric, p, metric_params, algorithm, leaf_size, n_jobs
@@ -686,7 +687,8 @@ def _set_reach_dist(
686687
"core_distances": [np.ndarray],
687688
"ordering": [np.ndarray],
688689
"eps": [Interval(Real, 0, None, closed="both")],
689-
}
690+
},
691+
prefer_skip_nested_validation=True,
690692
)
691693
def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps):
692694
"""Perform DBSCAN extraction for an arbitrary epsilon.
@@ -742,7 +744,8 @@ def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps):
742744
],
743745
"xi": [Interval(Real, 0, 1, closed="both")],
744746
"predecessor_correction": ["boolean"],
745-
}
747+
},
748+
prefer_skip_nested_validation=True,
746749
)
747750
def cluster_optics_xi(
748751
*,

sklearn/cluster/_spectral.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,10 @@ def discretize(
189189
return labels
190190

191191

192-
@validate_params({"affinity": ["array-like", "sparse matrix"]})
192+
@validate_params(
193+
{"affinity": ["array-like", "sparse matrix"]},
194+
prefer_skip_nested_validation=False,
195+
)
193196
def spectral_clustering(
194197
affinity,
195198
*,

sklearn/covariance/_empirical_covariance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def log_likelihood(emp_cov, precision):
5454
{
5555
"X": ["array-like"],
5656
"assume_centered": ["boolean"],
57-
}
57+
},
58+
prefer_skip_nested_validation=True,
5859
)
5960
def empirical_covariance(X, *, assume_centered=False):
6061
"""Compute the Maximum likelihood covariance estimator.

sklearn/covariance/_graph_lasso.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ def alpha_max(emp_cov):
217217
"cov_init": ["array-like", None],
218218
"return_costs": ["boolean"],
219219
"return_n_iter": ["boolean"],
220-
}
220+
},
221+
prefer_skip_nested_validation=False,
221222
)
222223
def graphical_lasso(
223224
emp_cov,

sklearn/covariance/_shrunk_covariance.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def _oas(X, *, assume_centered=False):
105105
{
106106
"emp_cov": ["array-like"],
107107
"shrinkage": [Interval(Real, 0, 1, closed="both")],
108-
}
108+
},
109+
prefer_skip_nested_validation=True,
109110
)
110111
def shrunk_covariance(emp_cov, shrinkage=0.1):
111112
"""Calculate a covariance matrix shrunk on the diagonal.
@@ -279,7 +280,8 @@ def fit(self, X, y=None):
279280
"X": ["array-like"],
280281
"assume_centered": ["boolean"],
281282
"block_size": [Interval(Integral, 1, None, closed="left")],
282-
}
283+
},
284+
prefer_skip_nested_validation=True,
283285
)
284286
def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
285287
"""Estimate the shrunk Ledoit-Wolf covariance matrix.
@@ -376,7 +378,10 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
376378
return shrinkage
377379

378380

379-
@validate_params({"X": ["array-like"]})
381+
@validate_params(
382+
{"X": ["array-like"]},
383+
prefer_skip_nested_validation=False,
384+
)
380385
def ledoit_wolf(X, *, assume_centered=False, block_size=1000):
381386
"""Estimate the shrunk Ledoit-Wolf covariance matrix.
382387
@@ -569,7 +574,10 @@ def fit(self, X, y=None):
569574

570575

571576
# OAS estimator
572-
@validate_params({"X": ["array-like"]})
577+
@validate_params(
578+
{"X": ["array-like"]},
579+
prefer_skip_nested_validation=False,
580+
)
573581
def oas(X, *, assume_centered=False):
574582
"""Estimate covariance with the Oracle Approximating Shrinkage as proposed in [1]_.
575583

0 commit comments

Comments
 (0)