Skip to content

Commit 2a548da

Browse files
FIX make dataset fetchers accept os.PathLike for data_home (scikit-learn#27468)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 8477d63 commit 2a548da

File tree

12 files changed: 51 additions and 31 deletions.

doc/whats_new/v1.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ Changelog
182182
which returns a dense numpy ndarray as before.
183183
:pr:`27438` by :user:`Yao Xiao <Charlie-XIAO>`.
184184

185+
- |Fix| All dataset fetchers now accept `data_home` as any object that implements
186+
the :class:`os.PathLike` interface, for instance, :class:`pathlib.Path`.
187+
:pr:`27468` by :user:`Yao Xiao <Charlie-XIAO>`.
188+
185189
:mod:`sklearn.decomposition`
186190
............................
187191

sklearn/datasets/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_data_home(data_home=None) -> str:
5757
----------
5858
data_home : str or path-like, default=None
5959
The path to scikit-learn data directory. If `None`, the default path
60-
is `~/sklearn_learn_data`.
60+
is `~/scikit_learn_data`.
6161
6262
Returns
6363
-------
@@ -84,7 +84,7 @@ def clear_data_home(data_home=None):
8484
----------
8585
data_home : str or path-like, default=None
8686
The path to scikit-learn data directory. If `None`, the default path
87-
is `~/sklearn_learn_data`.
87+
is `~/scikit_learn_data`.
8888
"""
8989
data_home = get_data_home(data_home)
9090
shutil.rmtree(data_home)

sklearn/datasets/_california_housing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import logging
2525
import tarfile
26-
from os import makedirs, remove
26+
from os import PathLike, makedirs, remove
2727
from os.path import exists
2828

2929
import joblib
@@ -53,7 +53,7 @@
5353

5454
@validate_params(
5555
{
56-
"data_home": [str, None],
56+
"data_home": [str, PathLike, None],
5757
"download_if_missing": ["boolean"],
5858
"return_X_y": ["boolean"],
5959
"as_frame": ["boolean"],
@@ -76,7 +76,7 @@ def fetch_california_housing(
7676
7777
Parameters
7878
----------
79-
data_home : str, default=None
79+
data_home : str or path-like, default=None
8080
Specify another download and cache folder for the datasets. By default
8181
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
8282

sklearn/datasets/_covtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666
@validate_params(
6767
{
68-
"data_home": [str, None],
68+
"data_home": [str, os.PathLike, None],
6969
"download_if_missing": ["boolean"],
7070
"random_state": ["random_state"],
7171
"shuffle": ["boolean"],
@@ -98,7 +98,7 @@ def fetch_covtype(
9898
9999
Parameters
100100
----------
101-
data_home : str, default=None
101+
data_home : str or path-like, default=None
102102
Specify another download and cache folder for the datasets. By default
103103
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
104104

sklearn/datasets/_kddcup99.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
@validate_params(
5151
{
5252
"subset": [StrOptions({"SA", "SF", "http", "smtp"}), None],
53-
"data_home": [str, None],
53+
"data_home": [str, os.PathLike, None],
5454
"shuffle": ["boolean"],
5555
"random_state": ["random_state"],
5656
"percent10": ["boolean"],
@@ -92,7 +92,7 @@ def fetch_kddcup99(
9292
To return the corresponding classical subsets of kddcup 99.
9393
If None, return the entire kddcup 99 dataset.
9494
95-
data_home : str, default=None
95+
data_home : str or path-like, default=None
9696
Specify another download and cache folder for the datasets. By default
9797
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
9898

sklearn/datasets/_lfw.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import logging
1212
from numbers import Integral, Real
13-
from os import listdir, makedirs, remove
13+
from os import PathLike, listdir, makedirs, remove
1414
from os.path import exists, isdir, join
1515

1616
import numpy as np
@@ -234,7 +234,7 @@ def _fetch_lfw_people(
234234

235235
@validate_params(
236236
{
237-
"data_home": [str, None],
237+
"data_home": [str, PathLike, None],
238238
"funneled": ["boolean"],
239239
"resize": [Interval(Real, 0, None, closed="neither"), None],
240240
"min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None],
@@ -272,7 +272,7 @@ def fetch_lfw_people(
272272
273273
Parameters
274274
----------
275-
data_home : str, default=None
275+
data_home : str or path-like, default=None
276276
Specify another download and cache folder for the datasets. By default
277277
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
278278
@@ -431,7 +431,7 @@ def _fetch_lfw_pairs(
431431
@validate_params(
432432
{
433433
"subset": [StrOptions({"train", "test", "10_folds"})],
434-
"data_home": [str, None],
434+
"data_home": [str, PathLike, None],
435435
"funneled": ["boolean"],
436436
"resize": [Interval(Real, 0, None, closed="neither"), None],
437437
"color": ["boolean"],
@@ -480,7 +480,7 @@ def fetch_lfw_pairs(
480480
official evaluation set that is meant to be used with a 10-folds
481481
cross validation.
482482
483-
data_home : str, default=None
483+
data_home : str or path-like, default=None
484484
Specify another download and cache folder for the datasets. By
485485
default all scikit-learn data is stored in '~/scikit_learn_data'
486486
subfolders.

sklearn/datasets/_olivetti_faces.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# Copyright (c) 2011 David Warde-Farley <wardefar at iro dot umontreal dot ca>
1414
# License: BSD 3 clause
1515

16-
from os import makedirs, remove
16+
from os import PathLike, makedirs, remove
1717
from os.path import exists
1818

1919
import joblib
@@ -36,7 +36,7 @@
3636

3737
@validate_params(
3838
{
39-
"data_home": [str, None],
39+
"data_home": [str, PathLike, None],
4040
"shuffle": ["boolean"],
4141
"random_state": ["random_state"],
4242
"download_if_missing": ["boolean"],
@@ -67,7 +67,7 @@ def fetch_olivetti_faces(
6767
6868
Parameters
6969
----------
70-
data_home : str, default=None
70+
data_home : str or path-like, default=None
7171
Specify another download and cache folder for the datasets. By default
7272
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
7373

sklearn/datasets/_openml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def _valid_data_column_names(features_list, target_columns):
749749
"name": [str, None],
750750
"version": [Interval(Integral, 1, None, closed="left"), StrOptions({"active"})],
751751
"data_id": [Interval(Integral, 1, None, closed="left"), None],
752-
"data_home": [str, None],
752+
"data_home": [str, os.PathLike, None],
753753
"target_column": [str, list, None],
754754
"cache": [bool],
755755
"return_X_y": [bool],
@@ -769,7 +769,7 @@ def fetch_openml(
769769
*,
770770
version: Union[str, int] = "active",
771771
data_id: Optional[int] = None,
772-
data_home: Optional[str] = None,
772+
data_home: Optional[Union[str, os.PathLike]] = None,
773773
target_column: Optional[Union[str, List]] = "default-target",
774774
cache: bool = True,
775775
return_X_y: bool = False,
@@ -815,7 +815,7 @@ def fetch_openml(
815815
dataset. If data_id is not given, name (and potential version) are
816816
used to obtain a dataset.
817817
818-
data_home : str, default=None
818+
data_home : str or path-like, default=None
819819
Specify another download and cache folder for the data sets. By default
820820
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
821821

sklearn/datasets/_rcv1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import logging
1212
from gzip import GzipFile
13-
from os import makedirs, remove
13+
from os import PathLike, makedirs, remove
1414
from os.path import exists, join
1515

1616
import joblib
@@ -74,7 +74,7 @@
7474

7575
@validate_params(
7676
{
77-
"data_home": [str, None],
77+
"data_home": [str, PathLike, None],
7878
"subset": [StrOptions({"train", "test", "all"})],
7979
"download_if_missing": ["boolean"],
8080
"random_state": ["random_state"],
@@ -111,7 +111,7 @@ def fetch_rcv1(
111111
112112
Parameters
113113
----------
114-
data_home : str, default=None
114+
data_home : str or path-like, default=None
115115
Specify another download and cache folder for the datasets. By default
116116
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
117117

sklearn/datasets/_species_distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
import logging
4141
from io import BytesIO
42-
from os import makedirs, remove
42+
from os import PathLike, makedirs, remove
4343
from os.path import exists
4444

4545
import joblib
@@ -136,7 +136,7 @@ def construct_grids(batch):
136136

137137

138138
@validate_params(
139-
{"data_home": [str, None], "download_if_missing": ["boolean"]},
139+
{"data_home": [str, PathLike, None], "download_if_missing": ["boolean"]},
140140
prefer_skip_nested_validation=True,
141141
)
142142
def fetch_species_distributions(*, data_home=None, download_if_missing=True):
@@ -146,7 +146,7 @@ def fetch_species_distributions(*, data_home=None, download_if_missing=True):
146146
147147
Parameters
148148
----------
149-
data_home : str, default=None
149+
data_home : str or path-like, default=None
150150
Specify another download and cache folder for the datasets. By default
151151
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
152152

0 commit comments

Comments
 (0)