8000 ENH fetch_openml should support return_X_y (#11840) · amueller/scikit-learn@52a36b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 52a36b4

Browse files
panzhufengqinhanmin2014
authored andcommitted
ENH fetch_openml should support return_X_y (scikit-learn#11840)
1 parent 757663c commit 52a36b4

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

sklearn/datasets/openml.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def _verify_target_data_type(features_dict, target_columns):
347347

348348

349349
def fetch_openml(name=None, version='active', data_id=None, data_home=None,
350-
target_column='default-target', cache=True):
350+
target_column='default-target', cache=True, return_X_y=False):
351351
"""Fetch dataset from openml by name or dataset id.
352352
353353
Datasets are uniquely identified by either an integer ID or by a
@@ -395,6 +395,10 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
395395
cache : boolean, default=True
396396
Whether to cache downloaded datasets using joblib.
397397
398+
return_X_y : boolean, default=False.
399+
If True, returns ``(data, target)`` instead of a Bunch object. See
400+
below for more information about the `data` and `target` objects.
401+
398402
Returns
399403
-------
400404
@@ -416,6 +420,8 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
416420
details : dict
417421
More metadata from OpenML
418422
423+
(data, target) : tuple if ``return_X_y`` is True
424+
419425
.. note:: EXPERIMENTAL
420426
421427
This interface is **experimental** as at version 0.20 and
@@ -557,6 +563,9 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
557563
elif y.shape[1] == 0:
558564
y = None
559565

566+
if return_X_y:
567+
return X, y
568+
560569
bunch = Bunch(
561570
data=X, target=y, feature_names=data_columns,
562571
DESCR=description, details=data_description,

sklearn/datasets/tests/test_openml.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
assert_raise_message)
1616
from sklearn.externals.six import string_types
1717
from sklearn.externals.six.moves.urllib.error import HTTPError
18+
from sklearn.datasets.tests.test_common import check_return_X_y
19+
from functools import partial
1820

1921

2022
currdir = os.path.dirname(os.path.abspath(__file__))
@@ -124,6 +126,11 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
124126
# np.isnan doesn't work on CSR matrix
125127
assert (np.count_nonzero(np.isnan(data_by_id.data)) ==
126128
expected_missing)
129+
130+
# test return_X_y option
131+
fetch_func = partial(fetch_openml, data_id=data_id, cache=False,
132+
target_column=target_column)
133+
check_return_X_y(data_by_id, fetch_func)
127134
return data_by_id
128135

129136

0 commit comments

Comments
 (0)
0