8000 ENH allows to overwrite read_csv parameter in fetch_openml (#26433) · primait/scikit-learn@0ed4374 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 0ed4374

Browse files
glemaitre, thomasjpfan, and adrinjalali
authored and committed
ENH allows to overwrite read_csv parameter in fetch_openml (scikit-learn#26433)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent f24c502 commit 0ed4374

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,11 @@ Changelog
271271
- |Fix| :func:`datasets.fetch_openml` returns improved data types when
272272
`as_frame=True` and `parser="liac-arff"`. :pr:`26386` by `Thomas Fan`_.
273273

274+
- |Enhancement| Allows overwriting the parameters used to open the ARFF file using
275+
the parameter `read_csv_kwargs` in :func:`datasets.fetch_openml` when using the
276+
pandas parser.
277+
:pr:`26433` by :user:`Guillaume Lemaitre <glemaitre>`.
278+
274279
:mod:`sklearn.decomposition`
275280
............................
276281

sklearn/datasets/_openml.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def _load_arff_response(
428428
md5_checksum: str,
429429
n_retries: int = 3,
430430
delay: float = 1.0,
431+
read_csv_kwargs: Optional[Dict] = None,
431432
):
432433
"""Load the ARFF data associated with the OpenML URL.
433434
@@ -470,6 +471,18 @@ def _load_arff_response(
470471
md5_checksum : str
471472
The MD5 checksum provided by OpenML to check the data integrity.
472473
474+
n_retries : int, default=3
475+
The number of times to retry downloading the data if it fails.
476+
477+
delay : float, default=1.0
478+
The delay between two consecutive downloads in seconds.
479+
480+
read_csv_kwargs : dict, default=None
481+
Keyword arguments to pass to `pandas.read_csv` when using the pandas parser.
482+
It allows overwriting the default options.
483+
484+
.. versionadded:: 1.3
485+
473486
Returns
474487
-------
475488
X : {ndarray, sparse matrix, dataframe}
@@ -506,13 +519,14 @@ def _open_url_and_load_gzip_file(url, data_home, n_retries, delay, arff_params):
506519
with closing(gzip_file):
507520
return load_arff_from_gzip_file(gzip_file, **arff_params)
508521

509-
arff_params = dict(
522+
arff_params: Dict = dict(
510523
parser=parser,
511524
output_type=output_type,
512525
openml_columns_info=openml_columns_info,
513526
feature_names_to_select=feature_names_to_select,
514527
target_names_to_select=target_names_to_select,
515528
shape=shape,
529+
read_csv_kwargs=read_csv_kwargs or {},
516530
)
517531
try:
518532
X, y, frame, categories = _open_url_and_load_gzip_file(
@@ -530,7 +544,7 @@ def _open_url_and_load_gzip_file(url, data_home, n_retries, delay, arff_params):
530544
# A parsing error could come from providing the wrong quotechar
531545
# to pandas. By default, we use a double quote. Thus, we retry
532546
# with a single quote before raising the error.
533-
arff_params["read_csv_kwargs"] = {"quotechar": "'"}
547+
arff_params["read_csv_kwargs"].update(quotechar="'")
534548
X, y, frame, categories = _open_url_and_load_gzip_file(
535549
url, data_home, n_retries, delay, arff_params
536550
)
@@ -552,6 +566,7 @@ def _download_data_to_bunch(
552566
n_retries: int = 3,
553567
delay: float = 1.0,
554568
parser: str,
569+
read_csv_kwargs: Optional[Dict] = None,
555570
):
556571
"""Download ARFF data, load it to a specific container and create a Bunch.
557572
@@ -598,6 +613,12 @@ def _download_data_to_bunch(
598613
parser : {"liac-arff", "pandas"}
599614
The parser used to parse the ARFF file.
600615
616+
read_csv_kwargs : dict, default=None
617+
Keyword arguments to pass to `pandas.read_csv` when using the pandas parser.
618+
It allows overwriting the default options.
619+
620+
.. versionadded:: 1.3
621+
601622
Returns
602623
-------
603624
data : :class:`~sklearn.utils.Bunch`
@@ -657,6 +678,7 @@ def _download_data_to_bunch(
657678
md5_checksum=md5_checksum,
658679
n_retries=n_retries,
659680
delay=delay,
681+
read_csv_kwargs=read_csv_kwargs,
660682
)
661683

662684
return Bunch(
@@ -725,6 +747,7 @@ def fetch_openml(
725747
n_retries: int = 3,
726748
delay: float = 1.0,
727749
parser: Optional[str] = "warn",
750+
read_csv_kwargs: Optional[Dict] = None,
728751
):
729752
"""Fetch dataset from openml by name or dataset id.
730753
@@ -829,6 +852,13 @@ def fetch_openml(
829852
warning. Therefore, an `ImportError` will be raised from 1.4 if
830853
the dataset is dense and pandas is not installed.
831854
855+
read_csv_kwargs : dict, default=None
856+
Keyword arguments passed to :func:`pandas.read_csv` when loading the data
857+
from an ARFF file and using the pandas parser. It can be used to
858+
overwrite some default parameters.
859+
860+
.. versionadded:: 1.3
861+
832862
Returns
833863
-------
834864
data : :class:`~sklearn.utils.Bunch`
@@ -1096,6 +1126,7 @@ def fetch_openml(
10961126
n_retries=n_retries,
10971127
delay=delay,
10981128
parser=parser_,
1129+
read_csv_kwargs=read_csv_kwargs,
10991130
)
11001131

11011132
if return_X_y:

sklearn/datasets/tests/test_openml.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,34 @@ def test_dataset_with_openml_warning(monkeypatch, gzip_response):
13541354
fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")
13551355

13561356

1357+
def test_fetch_openml_overwrite_default_params_read_csv(monkeypatch):
1358+
"""Check that we can overwrite the default parameters of `read_csv`."""
1359+
pytest.importorskip("pandas")
1360+
data_id = 1590
1361+
_monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)
1362+
1363+
common_params = {
1364+
"data_id": data_id,
1365+
"as_frame": True,
1366+
"cache": False,
1367+
"parser": "pandas",
1368+
}
1369+
1370+
# By default, the initial spaces are skipped. We check that setting the parameter
1371+
# `skipinitialspace` to False will have an effect.
1372+
adult_without_spaces = fetch_openml(**common_params)
1373+
adult_with_spaces = fetch_openml(
1374+
**common_params, read_csv_kwargs={"skipinitialspace": False}
1375+
)
1376+
assert all(
1377+
cat.startswith(" ") for cat in adult_with_spaces.frame["class"].cat.categories
1378+
)
1379+
assert not any(
1380+
cat.startswith(" ")
1381+
for cat in adult_without_spaces.frame["class"].cat.categories
1382+
)
1383+
1384+
13571385
###############################################################################
13581386
# Test cache, retry mechanisms, checksum, etc.
13591387

0 commit comments

Comments
 (0)
0