@@ -182,29 +182,89 @@ Tools
182
182
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
183
183
<https://github.com/scikit-learn/scikit-learn/issues/22352> `_ to track progress.
184
184
185
- Type of return values and fitted attributes
186
- -------------------------------------------
185
+ Input and output array type handling
186
+ ====================================
187
187
188
- When calling functions or methods with Array API compatible inputs, the
189
- convention is to return array values of the same array container type and
188
+ Estimators and scoring functions are able to accept input arrays
189
+ from different array libraries and/or devices. When a mixed set of input arrays is
190
+ passed, scikit-learn converts arrays as needed to make them all consistent.
191
+
192
+ For estimators, the rule is **"everything follows `X`" ** - mixed array inputs are
193
+ converted so that they all match the array library and device of `X `.
194
+ For scoring functions the rule is **"everything follows `y_pred`" ** - mixed array
195
+ inputs are converted so that they all match the array library and device of `y_pred `.
196
+
197
+ When a function or method has been called with array API compatible inputs, the
198
+ convention is to return arrays from the same array library and on the same
190
199
device as the input data.
191
200
192
- Similarly, when an estimator is fitted with Array API compatible inputs, the
193
- fitted attributes will be arrays from the same library as the input and stored
194
- on the same device. The `predict ` and `transform ` method subsequently expect
201
+ Estimators
202
+ ----------
203
+
204
+ When an estimator is fitted with an array API compatible `X `, all other
205
+ array inputs, including constructor arguments, (e.g., `y `, `sample_weight `)
206
+ will be converted to match the array library and device of `X `, if they do not already.
207
+ This behaviour enables switching from processing on the CPU to processing
208
+ on the GPU at any point within a pipeline.
209
+
210
+ This allows estimators to accept mixed input types, enabling `X ` to be moved
211
+ to a different device within a pipeline, without explicitly moving `y `.
212
+ Note that scikit-learn pipelines do not allow transformation of `y ` (to avoid
213
+ :ref: `leakage <data_leakage >`).
214
+
215
+ Take for example a pipeline where `X ` and `y ` both start on CPU, and go through
216
+ the following three steps:
217
+
218
+ * :class: `~sklearn.preprocessing.TargetEncoder `, which will transform categorial
219
+ `X ` but also requires `y `, meaning both `X ` and `y ` need to be on CPU.
220
+ * :class: `FunctionTransformer(func=partial(torch.asarray, device="cuda")) <sklearn.preprocessing.FunctionTransformer> `,
221
+ which moves `X ` to GPU, to improve performance in the next step.
222
+ * :class: `~sklearn.linear_model.Ridge `, whose performance can be improved when
223
+ passed arrays on a GPU, as they can handle large matrix operations very efficiently.
224
+
225
+ `X ` initially contains categorical string data (thus needs to be on CPU), which is
226
+ target encoded to numerical values in :class: `~sklearn.preprocessing.TargetEncoder `.
227
+ `X ` is then explicitly moved to GPU to improve the performance of
228
+ :class: `~sklearn.linear_model.Ridge `. `y ` cannot be transformed by the pipeline
229
+ (recall scikit-learn pipelines do not allow transformation of `y `) but as
230
+ :class: `~sklearn.linear_model.Ridge ` is able to accept mixed input types,
231
+ this is not a problem and the pipeline is able to be run.
232
+
233
+ The fitted attributes of an estimator fitted with an array API compatible `X `, will
234
+ be arrays from the same library as the input and stored on the same device.
235
+ The `predict ` and `transform ` method subsequently expect
195
236
inputs from the same array library and device as the data passed to the `fit `
196
237
method.
197
238
198
- Note however that scoring functions that return scalar values return Python
199
- scalars (typically a `float ` instance) instead of an array scalar value.
239
+ Scoring functions
240
+ -----------------
241
+
242
+ When an array API compatible `y_pred ` is passed to a scoring function,
243
+ all other array inputs (e.g., `y_true `, `sample_weight `) will be converted
244
+ to match the array library and device of `y_pred `, if they do not already.
245
+ This allows scoring functions to accept mixed input types, enabling them to be
246
+ used within a :term: `meta-estimator ` (or function that accepts estimators), with a
247
+ pipeline that moves input arrays between devices (e.g., CPU to GPU).
248
+
249
+ For example, to be able to use the pipeline described above within e.g.,
250
+ :func: `~sklearn.model_selection.cross_validate ` or
251
+ :class: `~sklearn.model_selection.GridSearchCV `, the scoring function internally
252
+ called needs to be able to accept mixed input types.
253
+
254
+ The output type of scoring functions depends on the number of output values.
255
+ When a scoring function returns a scalar value, it will return a Python
256
+ scalar (typically a `float ` instance) instead of an array scalar value.
257
+ For scoring functions that support :term: `multiclass ` or :term: `multioutput `,
258
+ an array from the same array library and device as `y_pred ` will be returned when
259
+ multiple values need to be output.
200
260
201
261
Common estimator checks
202
262
=======================
203
263
204
264
Add the `array_api_support ` tag to an estimator's set of tags to indicate that
205
- it supports the Array API. This will enable dedicated checks as part of the
265
+ it supports the array API. This will enable dedicated checks as part of the
206
266
common tests to verify that the estimators' results are the same when using
207
- vanilla NumPy and Array API inputs.
267
+ vanilla NumPy and array API inputs.
208
268
209
269
To run these checks you need to install
210
270
`array-api-strict <https://data-apis.org/array-api-strict/ >`_ in your
0 commit comments