Skip to content

Commit 12d45d3

Browse files
committed
Merging
Signed-off-by: Adam Li <adam2392@gmail.com>
2 parents 31723a6 + f0d6a9c commit 12d45d3

35 files changed

+600
-207
lines changed

.circleci/config.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ jobs:
2424
- OPENBLAS_NUM_THREADS: 2
2525
- CONDA_ENV_NAME: testenv
2626
- LOCK_FILE: build_tools/circle/doc_min_dependencies_linux-64_conda.lock
27-
# Sphinx race condition in doc-min-dependencies is causing job to stall
28-
# Here we run the job serially
29-
- SPHINX_NUMJOBS: 1
3027
steps:
3128
- checkout
3229
- run: ./build_tools/circle/checkout_merge_commit.sh
@@ -61,8 +58,6 @@ jobs:
6158
- OPENBLAS_NUM_THREADS: 2
6259
- CONDA_ENV_NAME: testenv
6360
- LOCK_FILE: build_tools/circle/doc_linux-64_conda.lock
64-
# Disable sphinx parallelism to avoid EOFError or job stalling in CircleCI
65-
- SPHINX_NUMJOBS: 1
6661
steps:
6762
- checkout
6863
- run: ./build_tools/circle/checkout_merge_commit.sh

build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

Lines changed: 35 additions & 40 deletions
Large diffs are not rendered by default.

build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ dependencies:
2020
- pytest-cov
2121
- coverage
2222
- ccache
23+
- pytorch=1.13
24+
- pytorch-cpu
25+
- array-api-compat

build_tools/update_environments_and_lock_files.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,15 @@ def remove_from(alist, to_remove):
8888
"folder": "build_tools/azure",
8989
"platform": "linux-64",
9090
"channel": "conda-forge",
91-
"conda_dependencies": common_dependencies + ["ccache"],
91+
"conda_dependencies": common_dependencies + [
92+
"ccache",
93+
"pytorch",
94+
"pytorch-cpu",
95+
"array-api-compat",
96+
],
9297
"package_constraints": {
9398
"blas": "[build=mkl]",
99+
"pytorch": "1.13",
94100
},
95101
},
96102
{

doc/Makefile

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,25 @@ SPHINXBUILD ?= sphinx-build
77
PAPER =
88
BUILDDIR = _build
99

10-
# Run sequential by default, unless SPHINX_NUMJOBS is set.
11-
SPHINX_NUMJOBS ?= 1
12-
1310
ifneq ($(EXAMPLES_PATTERN),)
1411
EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)"
1512
endif
1613

14+
ifeq ($(CI), true)
15+
# On CircleCI using -j2 does not seem to speed up the html-noplot build
16+
SPHINX_NUMJOBS_NOPLOT_DEFAULT=1
17+
else ($(shell uname), Darwin)
18+
# Avoid stalling issues on MacOS
19+
SPHINX_NUMJOBS_NOPLOT_DEFAULT=1
20+
else
21+
SPHINX_NUMJOBS_NOPLOT_DEFAULT=auto
22+
endif
23+
1724
# Internal variables.
1825
PAPEROPT_a4 = -D latex_paper_size=a4
1926
PAPEROPT_letter = -D latex_paper_size=letter
2027
ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
21-
-j$(SPHINX_NUMJOBS) $(EXAMPLES_PATTERN_OPTS) .
28+
$(EXAMPLES_PATTERN_OPTS) .
2229

2330

2431
.PHONY: help clean html dirhtml ziphtml pickle json latex latexpdf changes linkcheck doctest optipng
@@ -44,19 +51,27 @@ clean:
4451
-rm -rf generated/*
4552
-rm -rf modules/generated/
4653

54+
# Default to SPHINX_NUMJOBS=1 for full documentation build. Using
55+
# SPHINX_NUMJOBS!=1 may actually slow down the build, or cause weird issues in
56+
# the CI (job stalling or EOFError), see
57+
# https://github.com/scikit-learn/scikit-learn/pull/25836 or
58+
# https://github.com/scikit-learn/scikit-learn/pull/25809
59+
html: SPHINX_NUMJOBS ?= 1
4760
html:
4861
# These two lines make the build a bit more lengthy, and the
4962
# the embedding of images more robust
5063
rm -rf $(BUILDDIR)/html/_images
5164
#rm -rf _build/doctrees/
52-
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html/stable
65+
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) -j$(SPHINX_NUMJOBS) $(BUILDDIR)/html/stable
5366
@echo
5467
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html/stable"
5568

56-
# rm $(BUILDDIR)/html/stable/index.html
57-
# mv $(BUILDDIR)/html/stable/fork_index.html $(BUILDDIR)/html/stable/index.html
69+
# Default to SPHINX_NUMJOBS=auto (except on MacOS and CI) since this makes
70+
# html-noplot build faster
71+
html-noplot: SPHINX_NUMJOBS ?= $(SPHINX_NUMJOBS_NOPLOT_DEFAULT)
5872
html-noplot:
59-
$(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html/stable
73+
$(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) -j$(SPHINX_NUMJOBS) \
74+
$(BUILDDIR)/html/stable
6075
@echo
6176
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html/stable."
6277

doc/jupyter-lite.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"jupyter-lite-schema-version": 0,
3+
"jupyter-config-data": {
4+
"litePluginSettings": {
5+
"@jupyterlite/pyodide-kernel-extension:kernel": {
6+
"pyodideUrl": "https://cdn.jsdelivr.net/pyodide/v0.23.1/full/pyodide.js"
7+
}
8+
}
9+
}
10+
}

doc/jupyter_lite_config.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"LiteBuildConfig": {
3+
"no_sourcemaps": true
4+
}
5+
}

doc/modules/array_api.rst

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Array API support (experimental)
1212

1313
The `Array API <https://data-apis.org/array-api/latest/>`_ specification defines
1414
a standard API for all array manipulation libraries with a NumPy-like API.
15+
Scikit-learn's Array API support requires
16+
`array-api-compat <https://github.com/data-apis/array-api-compat>`__ to be installed.
1517

1618
Some scikit-learn estimators that primarily rely on NumPy (as opposed to using
1719
Cython) to implement the algorithmic logic of their `fit`, `predict` or
@@ -23,8 +25,8 @@ At this stage, this support is **considered experimental** and must be enabled
2325
explicitly as explained in the following.
2426

2527
.. note::
26-
Currently, only `cupy.array_api` and `numpy.array_api` are known to work
27-
with scikit-learn's estimators.
28+
Currently, only `cupy.array_api`, `numpy.array_api`, `cupy`, and `PyTorch`
29+
are known to work with scikit-learn's estimators.
2830

2931
Example usage
3032
=============
@@ -36,11 +38,11 @@ Here is an example code snippet to demonstrate how to use `CuPy
3638
>>> from sklearn.datasets import make_classification
3739
>>> from sklearn import config_context
3840
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
39-
>>> import cupy.array_api as xp
41+
>>> import cupy
4042

4143
>>> X_np, y_np = make_classification(random_state=0)
42-
>>> X_cu = xp.asarray(X_np)
43-
>>> y_cu = xp.asarray(y_np)
44+
>>> X_cu = cupy.asarray(X_np)
45+
>>> y_cu = cupy.asarray(y_np)
4446
>>> X_cu.device
4547
<CUDA Device 0>
4648

@@ -57,12 +59,30 @@ GPU. We provide a experimental `_estimator_with_converted_arrays` utility that
5759
transfers an estimator attributes from Array API to a ndarray::
5860

5961
>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
60-
>>> cupy_to_ndarray = lambda array : array._array.get()
62+
>>> cupy_to_ndarray = lambda array : array.get()
6163
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
6264
>>> X_trans = lda_np.transform(X_np)
6365
>>> type(X_trans)
6466
<class 'numpy.ndarray'>
6567

68+
PyTorch Support
69+
---------------
70+
71+
PyTorch Tensors are supported by setting `array_api_dispatch=True` and passing in
72+
the tensors directly::
73+
74+
>>> import torch
75+
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
76+
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)
77+
78+
>>> with config_context(array_api_dispatch=True):
79+
... lda = LinearDiscriminantAnalysis()
80+
... X_trans = lda.fit_transform(X_torch, y_torch)
81+
>>> type(X_trans)
82+
<class 'torch.Tensor'>
83+
>>> X_trans.device.type
84+
'cuda'
85+
6686
.. _array_api_estimators:
6787

6888
Estimators with support for `Array API`-compatible inputs

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,13 @@ Changelog
230230
:class:`decomposition.MiniBatchNMF` which can produce different results than previous
231231
versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.
232232

233+
:mod:`sklearn.discriminant_analysis`
234+
....................................
235+
236+
- |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now
237+
supports the `PyTorch <https://pytorch.org/>`__. See
238+
:ref:`array_api` for more details. :pr:`25956` by `Thomas Fan`_.
239+
233240
:mod:`sklearn.ensemble`
234241
.......................
235242

examples/model_selection/plot_permutation_tests_for_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
score_label = f"Score on original\ndata: {score_iris:.2f}\n(p-value: {pvalue_iris:.3f})"
9696
ax.text(0.7, 10, score_label, fontsize=12)
9797
ax.set_xlabel("Accuracy score")
98-
_ = ax.set_ylabel("Probability")
98+
_ = ax.set_ylabel("Probability density")
9999

100100
# %%
101101
# Random data
@@ -116,7 +116,7 @@
116116
score_label = f"Score on original\ndata: {score_rand:.2f}\n(p-value: {pvalue_rand:.3f})"
117117
ax.text(0.14, 7.5, score_label, fontsize=12)
118118
ax.set_xlabel("Accuracy score")
119-
ax.set_ylabel("Probability")
119+
ax.set_ylabel("Probability density")
120120
plt.show()
121121

122122
# %%

0 commit comments

Comments
 (0)