Merge branch 'main' into auto-update-lock-files-main · scikit-learn/scikit-learn@f9848e4 · GitHub
[go: up one dir, main page]

Skip to content

Commit f9848e4

Browse files
authored
Merge branch 'main' into auto-update-lock-files-main
2 parents 98c6f18 + 74d1307 commit f9848e4

24 files changed

+641
-102
lines changed

build_tools/azure/pypy3_linux-64_conda.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda#8b9
3232
https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda#f07002e225d7a60a694d42a7bf5ff53f
3333
https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda#5fc11c6020d421960607d821310fcd4d
3434
https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_5.conda#e73e9cfd1191783392131e6238bdb3e9
35-
https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.42-h2797004_0.conda#d67729828dc6ff7ba44a61062ad79880
35+
https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.43-h2797004_0.conda#009981dd9cfcaa4dbfa25ffaed86bcae
3636
https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.45.1-h2797004_0.conda#fc4ccadfbf6d4784de88c41704792562
3737
https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c
3838
https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda#47d31b792659ce70f470b5c82fdfb7a4
@@ -74,19 +74,19 @@ https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py39hcf8a34e_0.con
7474
https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976
7575
https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb
7676
https://conda.anaconda.org/conda-forge/noarch/pypy-7.3.15-1_pypy39.conda#a418a6c16bd6f7ed56b92194214791a0
77-
https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.0-pyhd8ed1ab_0.conda#6df2be294365eca602cabb4f04a6efe2
77+
https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142
7878
https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2
7979
https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.3.0-pyhc1e730c_0.conda#698d2d2b621640bddb9191f132967c9f
8080
https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96
81-
https://conda.anaconda.org/conda-forge/linux-64/tornado-6.3.3-py39hf860d4a_1.conda#ed9f2e116805d111f969b78e71203eef
81+
https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py39hf860d4a_0.conda#e7fded713fb466e1e0670afce1761b47
8282
https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-15.1.0-py39hf860d4a_0.conda#f699157518d28d00c87542b4ec1273be
8383
https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a
8484
https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-21_linux64_openblas.conda#77cefbfb4d47ba8cafef8e3f768a4538
8585
https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py39ha90811c_0.conda#f3b2afc64bf0cbe901a9b00d44611c61
8686
https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py39hf860d4a_0.conda#fa0d38d44f69d5c8ca476beb24fb456e
8787
https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.1.1-pyhd8ed1ab_0.conda#3d5fa25cf42f3f32a12b2d874ace8574
8888
https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc
89-
https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.0-pyhd8ed1ab_0.conda#5ba1cc5b924226349d4a49fb547b7579
89+
https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.2-pyhd8ed1ab_0.conda#40bd3ef942b9642a3eb20b0bbf92469b
9090
https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984
9191
https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py39h6dedee3_2.conda#6c5d74bac41838f4377dfd45085e1fec
9292
https://conda.anaconda.org/conda-forge/linux-64/blas-2.121-openblas.conda#4a279792fd8861a15705516a52872eb6

build_tools/cirrus/pymin_conda_forge_linux-aarch64_conda.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ https://conda.anaconda.org/conda-forge/linux-aarch64/xz-5.2.6-h9cdd2b7_0.tar.bz2
3030
https://conda.anaconda.org/conda-forge/linux-aarch64/libbrotlidec-1.1.0-h31becfc_1.conda#8db7cff89510bec0b863a0a8ee6a7bce
3131
https://conda.anaconda.org/conda-forge/linux-aarch64/libbrotlienc-1.1.0-h31becfc_1.conda#ad3d3a826b5848d99936e4466ebbaa26
3232
https://conda.anaconda.org/conda-forge/linux-aarch64/libgfortran-ng-13.2.0-he9431aa_5.conda#fab7c6a8c84492e18cbe578820e97a56
33-
https://conda.anaconda.org/conda-forge/linux-aarch64/libpng-1.6.42-h194ca79_0.conda#b8ff00cc9a5184726baea61244f8bec3
33+
https://conda.anaconda.org/conda-forge/linux-aarch64/libpng-1.6.43-h194ca79_0.conda#1123e504d9254dd9494267ab9aba95f0
3434
https://conda.anaconda.org/conda-forge/linux-aarch64/libsqlite-3.45.1-h194ca79_0.conda#4190198deb1ed253eb938f6a6d92ff4f
3535
https://conda.anaconda.org/conda-forge/linux-aarch64/libxcb-1.15-h2a766a3_0.conda#eb3d8c8170e3d03f2564ed2024aa00c8
3636
https://conda.anaconda.org/conda-forge/linux-aarch64/readline-8.2-h8fc344f_1.conda#105eb1e16bf83bfb2eb380a48032b655
@@ -61,11 +61,11 @@ https://conda.anaconda.org/conda-forge/linux-aarch64/openjpeg-2.5.0-h0d9d63b_3.c
6161
https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda#79002079284aa895f883c6b7f3f88fd6
6262
https://conda.anaconda.org/conda-forge/noarch/pluggy-1.4.0-pyhd8ed1ab_0.conda#139e9feb65187e916162917bb2484976
6363
https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.1-pyhd8ed1ab_0.conda#176f7d56f0cfe9008bdf1bccd7de02fb
64-
https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.0-pyhd8ed1ab_0.conda#6df2be294365eca602cabb4f04a6efe2
64+
https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda#576de899521b7d43674ba3ef6eae9142
6565
https://conda.anaconda.org/conda-forge/noarch/six-1.16.0-pyh6c4a22f_0.tar.bz2#e5f25f8dbc060e9a8d912e432202afc2
6666
https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.3.0-pyhc1e730c_0.conda#698d2d2b621640bddb9191f132967c9f
6767
https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2#5844808ffab9ebdb694585b50ba02a96
68-
https://conda.anaconda.org/conda-forge/linux-aarch64/tornado-6.3.3-py39h7cc1d5f_1.conda#c383c279123694d7a586ec47320d1cb1
68+
https://conda.anaconda.org/conda-forge/linux-aarch64/tornado-6.4-py39h7cc1d5f_0.conda#2c06a653ebfa389c18aea2d8f338df3b
6969
https://conda.anaconda.org/conda-forge/linux-aarch64/unicodedata2-15.1.0-py39h898b7ef_0.conda#8c072c9329aeea97a46005625267a851
7070
https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda#1cdea58981c5cbc17b51973bcaddcea7
7171
https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda#2e4d6bc0b14e10f895fc6791a7d9b26a
@@ -76,7 +76,7 @@ https://conda.anaconda.org/conda-forge/linux-aarch64/libcblas-3.9.0-21_linuxaarc
7676
https://conda.anaconda.org/conda-forge/linux-aarch64/liblapack-3.9.0-21_linuxaarch64_openblas.conda#ab08b651e3630c20d3032e59859f34f7
7777
https://conda.anaconda.org/conda-forge/linux-aarch64/pillow-10.2.0-py39h8ce38d7_0.conda#cf4745fb7f7cb5d0b90c476116c7d8ac
7878
https://conda.anaconda.org/conda-forge/noarch/pip-24.0-pyhd8ed1ab_0.conda#f586ac1e56c8638b64f9c8122a7b8a67
79-
https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.0-pyhd8ed1ab_0.conda#5ba1cc5b924226349d4a49fb547b7579
79+
https://conda.anaconda.org/conda-forge/noarch/pytest-8.0.2-pyhd8ed1ab_0.conda#40bd3ef942b9642a3eb20b0bbf92469b
8080
https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.2-pyhd8ed1ab_0.tar.bz2#dd999d1cc9f79e67dbb855c8924c7984
8181
https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.1.1-pyhd8ed1ab_0.conda#d04bd1b5bed9177dd7c3cef15e2b6710
8282
https://conda.anaconda.org/conda-forge/linux-aarch64/liblapacke-3.9.0-21_linuxaarch64_openblas.conda#be00a60ef5d88de133a28cb1fb6e0b31

doc/metadata_routing.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ Meta-estimators and functions supporting metadata routing:
287287
- :class:`sklearn.linear_model.LogisticRegressionCV`
288288
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
289289
- :class:`sklearn.linear_model.MultiTaskLassoCV`
290+
- :class:`sklearn.linear_model.RANSACRegressor`
290291
- :class:`sklearn.model_selection.GridSearchCV`
291292
- :class:`sklearn.model_selection.HalvingGridSearchCV`
292293
- :class:`sklearn.model_selection.HalvingRandomSearchCV`
@@ -315,6 +316,7 @@ Meta-estimators and tools not supporting metadata routing yet:
315316
- :class:`sklearn.feature_selection.RFE`
316317
- :class:`sklearn.feature_selection.RFECV`
317318
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
319+
- :class:`sklearn.impute.IterativeImputer`
318320
- :class:`sklearn.linear_model.RANSACRegressor`
319321
- :class:`sklearn.linear_model.RidgeClassifierCV`
320322
- :class:`sklearn.linear_model.RidgeCV`

doc/modules/linear_model.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,10 @@ Each iteration performs the following steps:
15361536

15371537
1. Select ``min_samples`` random samples from the original data and check
15381538
whether the set of data is valid (see ``is_data_valid``).
1539-
2. Fit a model to the random subset (``base_estimator.fit``) and check
1539+
2. Fit a model to the random subset (``estimator.fit``) and check
15401540
whether the estimated model is valid (see ``is_model_valid``).
15411541
3. Classify all data as inliers or outliers by calculating the residuals
1542-
to the estimated model (``base_estimator.predict(X) - y``) - all data
1542+
to the estimated model (``estimator.predict(X) - y``) - all data
15431543
samples with absolute residuals smaller than or equal to the
15441544
``residual_threshold`` are considered as inliers.
15451545
4. Save fitted model as best model if number of inlier samples is

doc/whats_new/v1.5.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ more details.
4848
via their `fit` methods.
4949
:pr:`28432` by :user:`Adam Li <adam2392>` and :user:`Benjamin Bossan <BenjaminBossan>`.
5050

51+
Metadata Routing
52+
----------------
53+
54+
The following models now support metadata routing in one or more of their
55+
methods. Refer to the :ref:`Metadata Routing User Guide <metadata_routing>` for
56+
more details.
57+
58+
- |Feature| :class:`linear_model.RANSACRegressor` now supports metadata routing
59+
in its ``fit``, ``score`` and ``predict`` methods and routes metadata to its
60+
underlying estimator's ``fit``, ``score`` and ``predict`` methods.
61+
:pr:`28261` by :user:`Stefanie Senger <StefanieSenger>`.
62+
5163
- |Feature| :class:`ensemble.VotingClassifier` and
5264
:class:`ensemble.VotingRegressor` now support metadata routing and pass
5365
``**fit_params`` to the underlying estimators via their `fit` methods.
@@ -67,6 +79,12 @@ Changelog
6779
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
6880
where 123455 is the *pull request* number, not the issue number.
6981
82+
:mod:`sklearn.cluster`
83+
......................
84+
85+
- |FIX| Create copy of precomputed sparse matrix within the `fit` method of
86+
:class:`~cluster.OPTICS` to avoid in-place modification of the sparse matrix.
87+
:pr:`28491` by :user:`Thanh Lam Dang <lamdang2k>`.
7088

7189
:mod:`sklearn.compose`
7290
......................
@@ -78,6 +96,23 @@ Changelog
7896
only `inverse_func` is provided without `func` (that would default to identity) being
7997
explicitly set as well. :pr:`28483` by :user:`Stefanie Senger <StefanieSenger>`.
8098

99+
:mod:`sklearn.datasets`
100+
.......................
101+
102+
- |Enhancement| Adds optional arguments `n_retries` and `delay` to functions
103+
:func:`datasets.fetch_20newsgroups`,
104+
:func:`datasets.fetch_20newsgroups_vectorized`,
105+
:func:`datasets.fetch_california_housing`,
106+
:func:`datasets.fetch_covtype`,
107+
:func:`datasets.fetch_kddcup99`,
108+
:func:`datasets.fetch_lfw_pairs`,
109+
:func:`datasets.fetch_lfw_people`,
110+
:func:`datasets.fetch_olivetti_faces`,
111+
:func:`datasets.fetch_rcv1`,
112+
and :func:`datasets.fetch_species_distributions`.
113+
By default, the functions will retry up to 3 times in case of network failures.
114+
:pr:`28160` by :user:`Zhehao Liu <MaxwellLZH>` and :user:`Filip Karlo Došilović <fkdosilovic>`.
115+
81116
:mod:`sklearn.dummy`
82117
....................
83118

sklearn/cluster/_optics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def fit(self, X, y=None):
333333

334334
X = self._validate_data(X, dtype=dtype, accept_sparse="csr")
335335
if self.metric == "precomputed" and issparse(X):
336+
X = X.copy() # copy to avoid in-place modification
336337
with warnings.catch_warnings():
337338
warnings.simplefilter("ignore", SparseEfficiencyWarning)
338339
# Set each diagonal to an explicit value so each point is its

sklearn/cluster/tests/test_optics.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,27 @@ def test_precomputed_dists(global_dtype, csr_container):
816816
assert_array_equal(clust1.labels_, clust2.labels_)
817817

818818

819+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
820+
def test_optics_input_not_modified_precomputed_sparse_nodiag(csr_container):
821+
"""Check that we don't modify in-place the pre-computed sparse matrix.
822+
Non-regression test for:
823+
https://github.com/scikit-learn/scikit-learn/issues/27508
824+
"""
825+
X = np.random.RandomState(0).rand(6, 6)
826+
# Add zeros on the diagonal that will be implicit when creating
827+
# the sparse matrix. If `X` is modified in-place, the zeros from
828+
# the diagonal will be made explicit.
829+
np.fill_diagonal(X, 0)
830+
X = csr_container(X)
831+
assert all(row != col for row, col in zip(*X.nonzero()))
832+
X_copy = X.copy()
833+
OPTICS(metric="precomputed").fit(X)
834+
# Make sure that we did not modify `X` in-place even by creating
835+
# explicit 0s values.
836+
assert X.nnz == X_copy.nnz
837+
assert_array_equal(X.toarray(), X_copy.toarray())
838+
839+
819840
def test_optics_predecessor_correction_ordering():
820841
"""Check that cluster correction using predecessor is working as expected.
821842

sklearn/datasets/_base.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
import hashlib
1212
import os
1313
import shutil
14+
import time
15+
import warnings
1416
from collections import namedtuple
1517
from importlib import resources
1618
from numbers import Integral
1719
from os import environ, listdir, makedirs
1820
from os.path import expanduser, isdir, join, splitext
1921
from pathlib import Path
22+
from urllib.error import URLError
2023
from urllib.request import urlretrieve
2124

2225
import numpy as np
@@ -1408,7 +1411,7 @@ def _sha256(path):
14081411
return sha256hash.hexdigest()
14091412

14101413

1411-
def _fetch_remote(remote, dirname=None):
1414+
def _fetch_remote(remote, dirname=None, n_retries=3, delay=1):
14121415
"""Helper function to download a remote dataset into path
14131416
14141417
Fetch a dataset pointed by remote's url, save into path using remote's
@@ -1424,14 +1427,35 @@ def _fetch_remote(remote, dirname=None):
14241427
dirname : str
14251428
Directory to save the file to.
14261429
1430+
n_retries : int, default=3
1431+
Number of retries when HTTP errors are encountered.
1432+
1433+
.. versionadded:: 1.5
1434+
1435+
delay : int, default=1
1436+
Number of seconds between retries.
1437+
1438+
.. versionadded:: 1.5
1439+
14271440
Returns
14281441
-------
14291442
file_path: str
14301443
Full path of the created file.
14311444
"""
14321445

14331446
file_path = remote.filename if dirname is None else join(dirname, remote.filename)
1434-
urlretrieve(remote.url, file_path)
1447+
while True:
1448+
try:
1449+
urlretrieve(remote.url, file_path)
1450+
break
1451+
except (URLError, TimeoutError):
1452+
if n_retries == 0:
1453+
# If no more retries are left, re-raise the caught exception.
1454+
raise
1455+
warnings.warn(f"Retry downloading from url: {remote.url}")
1456+
n_retries -= 1
1457+
time.sleep(delay)
1458+
14351459
checksum = _sha256(file_path)
14361460
if remote.checksum != checksum:
14371461
raise OSError(

sklearn/datasets/_california_housing.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@
2323

2424
import logging
2525
import tarfile
26+
from numbers import Integral, Real
2627
from os import PathLike, makedirs, remove
2728
from os.path import exists
2829

2930
import joblib
3031
import numpy as np
3132

3233
from ..utils import Bunch
33-
from ..utils._param_validation import validate_params
34+
from ..utils._param_validation import Interval, validate_params
3435
from . import get_data_home
3536
from ._base import (
3637
RemoteFileMetadata,
@@ -57,11 +58,19 @@
5758
"download_if_missing": ["boolean"],
5859
"return_X_y": ["boolean"],
5960
"as_frame": ["boolean"],
61+
"n_retries": [Interval(Integral, 1, None, closed="left")],
62+
"delay": [Interval(Real, 0.0, None, closed="neither")],
6063
},
6164
prefer_skip_nested_validation=True,
6265
)
6366
def fetch_california_housing(
64-
*, data_home=None, download_if_missing=True, return_X_y=False, as_frame=False
67+
*,
68+
data_home=None,
69+
download_if_missing=True,
70+
return_X_y=False,
71+
as_frame=False,
72+
n_retries=3,
73+
delay=1.0,
6574
):
6675
"""Load the California housing dataset (regression).
6776
@@ -97,6 +106,16 @@ def fetch_california_housing(
97106
98107
.. versionadded:: 0.23
99108
109+
n_retries : int, default=3
110+
Number of retries when HTTP errors are encountered.
111+
112+
.. versionadded:: 1.5
113+
114+
delay : float, default=1.0
115+
Number of seconds between retries.
116+
117+
.. versionadded:: 1.5
118+
100119
Returns
101120
-------
102121
dataset : :class:`~sklearn.utils.Bunch`
@@ -154,7 +173,12 @@ def fetch_california_housing(
154173
"Downloading Cal. housing from {} to {}".format(ARCHIVE.url, data_home)
155174
)
156175

157-
archive_path = _fetch_remote(ARCHIVE, dirname=data_home)
176+
archive_path = _fetch_remote(
177+
ARCHIVE,
178+
dirname=data_home,
179+
n_retries=n_retries,
180+
delay=delay,
181+
)
158182

159183
with tarfile.open(mode="r:gz", name=archive_path) as f:
160184
cal_housing = np.loadtxt(

sklearn/datasets/_covtype.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
import logging
1818
import os
1919
from gzip import GzipFile
20+
from numbers import Integral, Real
2021
from os.path import exists, join
2122
from tempfile import TemporaryDirectory
2223

2324
import joblib
2425
import numpy as np
2526

2627
from ..utils import Bunch, check_random_state
27-
from ..utils._param_validation import validate_params
28+
from ..utils._param_validation import Interval, validate_params
2829
from . import get_data_home
2930
from ._base import (
3031
RemoteFileMetadata,
@@ -71,6 +72,8 @@
7172
"shuffle": ["boolean"],
7273
"return_X_y": ["boolean"],
7374
"as_frame": ["boolean"],
75+
"n_retries": [Interval(Integral, 1, None, closed="left")],
76+
"delay": [Interval(Real, 0.0, None, closed="neither")],
7477
},
7578
prefer_skip_nested_validation=True,
7679
)
@@ -82,6 +85,8 @@ def fetch_covtype(
8285
shuffle=False,
8386
return_X_y=False,
8487
as_frame=False,
88+
n_retries=3,
89+
delay=1.0,
8590
):
8691
"""Load the covertype dataset (classification).
8792
@@ -129,6 +134,16 @@ def fetch_covtype(
129134
130135
.. versionadded:: 0.24
131136
137+
n_retries : int, default=3
138+
Number of retries when HTTP errors are encountered.
139+
140+
.. versionadded:: 1.5
141+
142+
delay : float, default=1.0
143+
Number of seconds between retries.
144+
145+
.. versionadded:: 1.5
146+
132147
Returns
133148
-------
134149
dataset : :class:`~sklearn.utils.Bunch`
@@ -183,7 +198,9 @@ def fetch_covtype(
183198
# os.rename to atomically move the data files to their target location.
184199
with TemporaryDirectory(dir=covtype_dir) as temp_dir:
185200
logger.info(f"Downloading {ARCHIVE.url}")
186-
archive_path = _fetch_remote(ARCHIVE, dirname=temp_dir)
201+
archive_path = _fetch_remote(
202+
ARCHIVE, dirname=temp_dir, n_retries=n_retries, delay=delay
203+
)
187204
Xy = np.genfromtxt(GzipFile(filename=archive_path), delimiter=",")
188205

189206
X = Xy[:, :-1]

0 commit comments

Comments
 (0)
0