Skip to content

Commit 20d33d5

Browse files
authored
DOC Clarify how mixed array input types handled in array api (scikit-learn#31452)
1 parent 6ccb204 commit 20d33d5

File tree

1 file changed

+71
-11
lines changed

1 file changed

+71
-11
lines changed

doc/modules/array_api.rst

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -182,29 +182,89 @@ Tools
182182
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
183183
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
184184

185-
Type of return values and fitted attributes
186-
-------------------------------------------
185+
Input and output array type handling
186+
====================================
187187

188-
When calling functions or methods with Array API compatible inputs, the
189-
convention is to return array values of the same array container type and
188+
Estimators and scoring functions are able to accept input arrays
189+
from different array libraries and/or devices. When a mixed set of input arrays is
190+
passed, scikit-learn converts arrays as needed to make them all consistent.
191+
192+
For estimators, the rule is **"everything follows `X`"** - mixed array inputs are
193+
converted so that they all match the array library and device of `X`.
194+
For scoring functions the rule is **"everything follows `y_pred`"** - mixed array
195+
inputs are converted so that they all match the array library and device of `y_pred`.
196+
197+
When a function or method has been called with array API compatible inputs, the
198+
convention is to return arrays from the same array library and on the same
190199
device as the input data.
191200

192-
Similarly, when an estimator is fitted with Array API compatible inputs, the
193-
fitted attributes will be arrays from the same library as the input and stored
194-
on the same device. The `predict` and `transform` method subsequently expect
201+
Estimators
202+
----------
203+
204+
When an estimator is fitted with an array API compatible `X`, all other
205+
array inputs, including constructor arguments, (e.g., `y`, `sample_weight`)
206+
will be converted to match the array library and device of `X`, if they do not already.
207+
This behaviour enables switching from processing on the CPU to processing
208+
on the GPU at any point within a pipeline.
209+
210+
This allows estimators to accept mixed input types, enabling `X` to be moved
211+
to a different device within a pipeline, without explicitly moving `y`.
212+
Note that scikit-learn pipelines do not allow transformation of `y` (to avoid
213+
:ref:`leakage <data_leakage>`).
214+
215+
Take for example a pipeline where `X` and `y` both start on CPU, and go through
216+
the following three steps:
217+
218+
* :class:`~sklearn.preprocessing.TargetEncoder`, which will transform categorial
219+
`X` but also requires `y`, meaning both `X` and `y` need to be on CPU.
220+
* :class:`FunctionTransformer(func=partial(torch.asarray, device="cuda")) <sklearn.preprocessing.FunctionTransformer>`,
221+
which moves `X` to GPU, to improve performance in the next step.
222+
* :class:`~sklearn.linear_model.Ridge`, whose performance can be improved when
223+
passed arrays on a GPU, as they can handle large matrix operations very efficiently.
224+
225+
`X` initially contains categorical string data (thus needs to be on CPU), which is
226+
target encoded to numerical values in :class:`~sklearn.preprocessing.TargetEncoder`.
227+
`X` is then explicitly moved to GPU to improve the performance of
228+
:class:`~sklearn.linear_model.Ridge`. `y` cannot be transformed by the pipeline
229+
(recall scikit-learn pipelines do not allow transformation of `y`) but as
230+
:class:`~sklearn.linear_model.Ridge` is able to accept mixed input types,
231+
this is not a problem and the pipeline is able to be run.
232+
233+
The fitted attributes of an estimator fitted with an array API compatible `X`, will
234+
be arrays from the same library as the input and stored on the same device.
235+
The `predict` and `transform` method subsequently expect
195236
inputs from the same array library and device as the data passed to the `fit`
196237
method.
197238

198-
Note however that scoring functions that return scalar values return Python
199-
scalars (typically a `float` instance) instead of an array scalar value.
239+
Scoring functions
240+
-----------------
241+
242+
When an array API compatible `y_pred` is passed to a scoring function,
243+
all other array inputs (e.g., `y_true`, `sample_weight`) will be converted
244+
to match the array library and device of `y_pred`, if they do not already.
245+
This allows scoring functions to accept mixed input types, enabling them to be
246+
used within a :term:`meta-estimator` (or function that accepts estimators), with a
247+
pipeline that moves input arrays between devices (e.g., CPU to GPU).
248+
249+
For example, to be able to use the pipeline described above within e.g.,
250+
:func:`~sklearn.model_selection.cross_validate` or
251+
:class:`~sklearn.model_selection.GridSearchCV`, the scoring function internally
252+
called needs to be able to accept mixed input types.
253+
254+
The output type of scoring functions depends on the number of output values.
255+
When a scoring function returns a scalar value, it will return a Python
256+
scalar (typically a `float` instance) instead of an array scalar value.
257+
For scoring functions that support :term:`multiclass` or :term:`multioutput`,
258+
an array from the same array library and device as `y_pred` will be returned when
259+
multiple values need to be output.
200260

201261
Common estimator checks
202262
=======================
203263

204264
Add the `array_api_support` tag to an estimator's set of tags to indicate that
205-
it supports the Array API. This will enable dedicated checks as part of the
265+
it supports the array API. This will enable dedicated checks as part of the
206266
common tests to verify that the estimators' results are the same when using
207-
vanilla NumPy and Array API inputs.
267+
vanilla NumPy and array API inputs.
208268

209269
To run these checks you need to install
210270
`array-api-strict <https://data-apis.org/array-api-strict/>`_ in your

0 commit comments

Comments
 (0)