diff --git a/.circleci/config.yml b/.circleci/config.yml
index f9e7769d0c02d..fb5bb2a686f69 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -10,7 +10,6 @@ jobs:
       - PYTHON_VERSION: 3.5
       - NUMPY_VERSION: 1.11.0
       - SCIPY_VERSION: 0.17.0
-      - PANDAS_VERSION: 0.18.0
       - MATPLOTLIB_VERSION: 1.5.1
       - SCIKIT_IMAGE_VERSION: 0.12.3
     steps:
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 0c975003159ef..50baa55f607c8 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -20,6 +20,7 @@ jobs:
         INSTALL_MKL: 'false'
         NUMPY_VERSION: '1.11.0'
         SCIPY_VERSION: '0.17.0'
+        PANDAS_VERSION: '*'
        CYTHON_VERSION: '*'
         PILLOW_VERSION: '4.0.0'
         MATPLOTLIB_VERSION: '1.5.1'
diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 8269f336bbe57..00d816d30c023 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -39,6 +39,12 @@ Changelog
    :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
    where 123456 is the *pull request* number, not the issue number.
 
+:mod:`sklearn.datasets`
+.......................
+
+- |Feature| :func:`datasets.fetch_openml` now supports heterogeneous data using pandas
+  by setting `as_frame=True`. :pr:`13902` by `Thomas Fan`_.
+
 :mod:`sklearn.decomposition`
 ..................
diff --git a/examples/compose/plot_column_transformer_mixed_types.py b/examples/compose/plot_column_transformer_mixed_types.py
index 264ae7495296c..0f6c5d3c222c6 100644
--- a/examples/compose/plot_column_transformer_mixed_types.py
+++ b/examples/compose/plot_column_transformer_mixed_types.py
@@ -24,10 +24,10 @@
 #
 # License: BSD 3 clause
 
-import pandas as pd
 import numpy as np
 
 from sklearn.compose import ColumnTransformer
+from sklearn.datasets import fetch_openml
 from sklearn.pipeline import Pipeline
 from sklearn.impute import SimpleImputer
 from sklearn.preprocessing import StandardScaler, OneHotEncoder
@@ -37,9 +37,13 @@
 
 np.random.seed(0)
 
 # Read data from Titanic dataset.
-titanic_url = ('https://raw.githubusercontent.com/amueller/'
-               'scipy-2017-sklearn/091d371/notebooks/datasets/titanic3.csv')
-data = pd.read_csv(titanic_url)
+titanic = fetch_openml(data_id=40945, as_frame=True)
+X = titanic.data
+y = titanic.target
+
+# Alternatively, X and y can be obtained directly from the frame attribute:
+# X = titanic.frame.drop('survived', axis=1)
+# y = titanic.frame['survived']
 
 # We will train our classifier with the following features:
 # Numeric Features:
@@ -71,9 +75,6 @@
 clf = Pipeline(steps=[('preprocessor', preprocessor),
                       ('classifier', LogisticRegression())])
 
-X = data.drop('survived', axis=1)
-y = data['survived']
-
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
 
 clf.fit(X_train, y_train)
diff --git a/sklearn/datasets/openml.py b/sklearn/datasets/openml.py
index 6f76ee15e2e40..3d82027e29118 100644
--- a/sklearn/datasets/openml.py
+++ b/sklearn/datasets/openml.py
@@ -8,6 +8,7 @@
 from functools import wraps
 import itertools
 from collections.abc import Generator
+from collections import OrderedDict
 
 from urllib.request import urlopen, Request
 
@@ -18,6 +19,9 @@
 from .base import get_data_home
 from urllib.error import HTTPError
 from ..utils import Bunch
+from ..utils import get_chunk_n_rows
+from ..utils import _chunk_generator
+from ..utils import check_pandas_support  # noqa
 
 __all__ = ['fetch_openml']
 
@@ -263,6 +267,69 @@ def _convert_arff_data(arff_data, col_slice_x, col_slice_y, shape=None):
     raise ValueError('Unexpected Data Type obtained from arff.')
 
 
+def _feature_to_dtype(feature):
+    """Map feature to dtype for pandas DataFrame."""
+    if feature['data_type'] == 'string':
+        return object
+    elif feature['data_type'] == 'nominal':
+        return 'category'
+    # only numeric, integer, real are left
+    elif (feature['number_of_missing_values'] != '0' or
+          feature['data_type'] in ['numeric', 'real']):
+        # cast to floats when there are any missing values
+        return np.float64
+    elif feature['data_type'] == 'integer':
+        return np.int64
+    raise ValueError('Unsupported feature: {}'.format(feature))
+
+
+def _convert_arff_data_dataframe(arff, columns, features_dict):
+    """Convert the ARFF object into a pandas DataFrame.
+
+    Parameters
+    ----------
+    arff : dict
+        As obtained from liac-arff.
+
+    columns : list
+        Columns from dataframe to return.
+
+    features_dict : dict
+        Maps feature name to feature info from openml.
+
+    Returns
+    -------
+    dataframe : pandas DataFrame
+        The ARFF data converted into a pandas DataFrame.
+    """
+    pd = check_pandas_support('fetch_openml with as_frame=True')
+
+    attributes = OrderedDict(arff['attributes'])
+    arff_columns = list(attributes)
+
+    # calculate chunksize
+    first_row = next(arff['data'])
+    first_df = pd.DataFrame([first_row], columns=arff_columns)
+
+    row_bytes = first_df.memory_usage(deep=True).sum()
+    chunksize = get_chunk_n_rows(row_bytes)
+
+    # read arff data in chunks
+    columns_to_keep = [col for col in arff_columns if col in columns]
+    dfs = []
+    dfs.append(first_df[columns_to_keep])
+    for data in _chunk_generator(arff['data'], chunksize):
+        dfs.append(pd.DataFrame(data, columns=arff_columns)[columns_to_keep])
+    df = pd.concat(dfs)
+
+    for column in columns_to_keep:
+        dtype = _feature_to_dtype(features_dict[column])
+        if dtype == 'category':
+            dtype = pd.api.types.CategoricalDtype(attributes[column])
+        df[column] = df[column].astype(dtype, copy=False)
+    return df
+
+
 def _get_data_info_by_name(name, version, data_home):
     """
     Utilizes the openml dataset listing api to find a dataset by
@@ -436,7 +503,8 @@ def _valid_data_column_names(features_list, target_columns):
 
 
 def fetch_openml(name=None, version='active', data_id=None, data_home=None,
-                 target_column='default-target', cache=True, return_X_y=False):
+                 target_column='default-target', cache=True, return_X_y=False,
+                 as_frame=False):
     """Fetch dataset from openml by name or dataset id.
 
     Datasets are uniquely identified by either an integer ID or by a
@@ -489,26 +557,39 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
         If True, returns ``(data, target)`` instead of a Bunch object. See
         below for more information about the `data` and `target` objects.
 
+    as_frame : boolean, default=False
+        If True, the data is a pandas DataFrame including columns with
+        appropriate dtypes (numeric, string or categorical). The target is
+        a pandas DataFrame or Series depending on the number of target columns.
+        The Bunch will contain a ``frame`` attribute with the target and the
+        data. If ``return_X_y`` is True, then ``(data, target)`` will be pandas
+        DataFrames or Series as described above.
+
     Returns
     -------
     data : Bunch
         Dictionary-like object, with attributes:
 
-        data : np.array or scipy.sparse.csr_matrix of floats
+        data : np.array, scipy.sparse.csr_matrix of floats, or pandas DataFrame
             The feature matrix. Categorical features are encoded as ordinals.
-        target : np.array
+        target : np.array, pandas Series or DataFrame
             The regression target or classification labels, if applicable.
-            Dtype is float if numeric, and object if categorical.
+            Dtype is float if numeric, and object if categorical. If
+            ``as_frame`` is True, ``target`` is a pandas object.
         DESCR : str
            The full description of the dataset
        feature_names : list
            The names of the dataset columns
-        categories : dict
+        categories : dict or None
             Maps each categorical feature name to a list of values, such
-            that the value encoded as i is ith in the list.
+            that the value encoded as i is the ith in the list. If
+            ``as_frame`` is True, this is None.
         details : dict
             More metadata from OpenML
+        frame : pandas DataFrame
+            Only present when `as_frame=True`. DataFrame with ``data`` and
+            ``target``.
 
     (data, target) : tuple if ``return_X_y`` is True
 
@@ -568,41 +649,52 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
         warn("OpenML raised a warning on the dataset. It might be "
Warning: {}".format(data_description['warning'])) + return_sparse = False + if data_description['format'].lower() == 'sparse_arff': + return_sparse = True + + if as_frame and return_sparse: + raise ValueError('Cannot return dataframe with sparse data') + # download data features, meta-info about column types features_list = _get_data_features(data_id, data_home) - for feature in features_list: - if 'true' in (feature['is_ignore'], feature['is_row_identifier']): - continue - if feature['data_type'] == 'string': - raise ValueError('STRING attributes are not yet supported') + if not as_frame: + for feature in features_list: + if 'true' in (feature['is_ignore'], feature['is_row_identifier']): + continue + if feature['data_type'] == 'string': + raise ValueError('STRING attributes are not supported for ' + 'array representation. Try as_frame=True') if target_column == "default-target": # determines the default target based on the data feature results # (which is currently more reliable than the data description; # see issue: https://github.com/openml/OpenML/issues/768) - target_column = [feature['name'] for feature in features_list - if feature['is_target'] == 'true'] + target_columns = [feature['name'] for feature in features_list + if feature['is_target'] == 'true'] elif isinstance(target_column, str): # for code-simplicity, make target_column by default a list - target_column = [target_column] + target_columns = [target_column] elif target_column is None: - target_column = [] - elif not isinstance(target_column, list): + target_columns = [] + elif isinstance(target_column, list): + target_columns = target_column + else: raise TypeError("Did not recognize type of target_column" "Should be str, list or None. Got: " "{}".format(type(target_column))) data_columns = _valid_data_column_names(features_list, - target_column) + target_columns) # prepare which columns and data types should be returned for the X and y features_dict = {feature['name']: feature for feature in features_list} # XXX: col_slice_y should be all nominal or all numeric - _verify_target_data_type(features_dict, target_column) + _verify_target_data_type(features_dict, target_columns) col_slice_y = [int(features_dict[col_name]['index']) - for col_name in target_column] + for col_name in target_columns] col_slice_x = [int(features_dict[col_name]['index']) for col_name in data_columns] @@ -615,10 +707,6 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None, 'columns. '.format(feat['name'], nr_missing)) # determine arff encoding to return - return_sparse = False - if data_description['format'].lower() == 'sparse_arff': - return_sparse = True - if not return_sparse: data_qualities = _get_data_qualities(data_id, data_home) shape = _get_data_shape(data_qualities) @@ -631,46 +719,62 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None, # obtain the data arff = _download_data_arff(data_description['file_id'], return_sparse, - data_home) - - # nominal attributes is a dict mapping from the attribute name to the - # possible values. 
-    # Includes also the target column (which will be popped
-    # off below, before it will be packed in the Bunch object)
-    nominal_attributes = {k: v for k, v in arff['attributes']
-                          if isinstance(v, list) and
-                          k in data_columns + target_column}
-
-    X, y = _convert_arff_data(arff['data'], col_slice_x, col_slice_y, shape)
-
-    is_classification = {col_name in nominal_attributes
-                         for col_name in target_column}
-    if not is_classification:
-        # No target
-        pass
-    elif all(is_classification):
-        y = np.hstack([np.take(np.asarray(nominal_attributes.pop(col_name),
-                                          dtype='O'),
-                               y[:, i:i+1].astype(int, copy=False))
-                       for i, col_name in enumerate(target_column)])
-    elif any(is_classification):
-        raise ValueError('Mix of nominal and non-nominal targets is not '
-                         'currently supported')
+                               data_home, encode_nominal=not as_frame)
 
     description = "{}\n\nDownloaded from openml.org.".format(
         data_description.pop('description'))
 
-    # reshape y back to 1-D array, if there is only 1 target column; back
-    # to None if there are not target columns
-    if y.shape[1] == 1:
-        y = y.reshape((-1,))
-    elif y.shape[1] == 0:
-        y = None
+    nominal_attributes = None
+    frame = None
+    if as_frame:
+        columns = data_columns + target_columns
+        frame = _convert_arff_data_dataframe(arff, columns, features_dict)
+        X = frame[data_columns]
+        if len(target_columns) >= 2:
+            y = frame[target_columns]
+        elif len(target_columns) == 1:
+            y = frame[target_columns[0]]
+        else:
+            y = None
+    else:
+        # nominal attributes is a dict mapping from the attribute name to
+        # the possible values. It also includes the target column (which
+        # will be popped off below, before being packed in the Bunch object)
+        nominal_attributes = {k: v for k, v in arff['attributes']
+                              if isinstance(v, list) and
+                              k in data_columns + target_columns}
+
+        X, y = _convert_arff_data(arff['data'], col_slice_x,
+                                  col_slice_y, shape)
+
+        is_classification = {col_name in nominal_attributes
+                             for col_name in target_columns}
+        if not is_classification:
+            # No target
+            pass
+        elif all(is_classification):
+            y = np.hstack([
+                np.take(
+                    np.asarray(nominal_attributes.pop(col_name), dtype='O'),
+                    y[:, i:i + 1].astype(int, copy=False))
+                for i, col_name in enumerate(target_columns)
+            ])
+        elif any(is_classification):
+            raise ValueError('Mix of nominal and non-nominal targets is not '
+                             'currently supported')
+
+        # reshape y back to 1-D array, if there is only 1 target column;
+        # back to None if there are no target columns
+        if y.shape[1] == 1:
+            y = y.reshape((-1,))
+        elif y.shape[1] == 0:
+            y = None
 
     if return_X_y:
         return X, y
 
     bunch = Bunch(
-        data=X, target=y, feature_names=data_columns,
+        data=X, target=y, frame=frame, feature_names=data_columns,
         DESCR=description, details=data_description,
         categories=nominal_attributes,
         url="https://www.openml.org/d/{}".format(data_id))
diff --git a/sklearn/datasets/tests/data/openml/40945/api-v1-json-data-qualities-40945.json.gz b/sklearn/datasets/tests/data/openml/40945/api-v1-json-data-qualities-40945.json.gz
new file mode 100644
index 0000000000000..279a0bd82ad66
Binary files /dev/null and b/sklearn/datasets/tests/data/openml/40945/api-v1-json-data-qualities-40945.json.gz differ
diff --git a/sklearn/datasets/tests/data/openml/40945/data-v1-download-16826755.arff.gz b/sklearn/datasets/tests/data/openml/40945/data-v1-download-16826755.arff.gz
new file mode 100644
index 0000000000000..824fd370dd582
Binary files /dev/null and b/sklearn/datasets/tests/data/openml/40945/data-v1-download-16826755.arff.gz differ
diff --git a/sklearn/datasets/tests/test_openml.py b/sklearn/datasets/tests/test_openml.py
index 9c8200731aa6d..de13f96675f16 100644
--- a/sklearn/datasets/tests/test_openml.py
+++ b/sklearn/datasets/tests/test_openml.py
@@ -9,15 +9,18 @@
 import sklearn
 import pytest
 
+from sklearn import config_context
 from sklearn.datasets import fetch_openml
 from sklearn.datasets.openml import (_open_openml_url,
                                      _get_data_description_by_id,
                                      _download_data_arff,
                                      _get_local_path,
-                                     _retry_with_clean_cache)
+                                     _retry_with_clean_cache,
+                                     _feature_to_dtype)
 from sklearn.utils.testing import (assert_warns_message,
                                    assert_raise_message)
 from sklearn.utils import is_scalar_nan
+from sklearn.utils.testing import assert_allclose, assert_array_equal
 from urllib.error import HTTPError
 from sklearn.datasets.tests.test_common import check_return_X_y
 from functools import partial
@@ -255,6 +258,428 @@ def _mock_urlopen(request):
     context.setattr(sklearn.datasets.openml, 'urlopen', _mock_urlopen)
 
 
+@pytest.mark.parametrize('feature, expected_dtype', [
+    ({'data_type': 'string', 'number_of_missing_values': '0'}, object),
+    ({'data_type': 'string', 'number_of_missing_values': '1'}, object),
+    ({'data_type': 'numeric', 'number_of_missing_values': '0'}, np.float64),
+    ({'data_type': 'numeric', 'number_of_missing_values': '1'}, np.float64),
+    ({'data_type': 'real', 'number_of_missing_values': '0'}, np.float64),
+    ({'data_type': 'real', 'number_of_missing_values': '1'}, np.float64),
+    ({'data_type': 'integer', 'number_of_missing_values': '0'}, np.int64),
+    ({'data_type': 'integer', 'number_of_missing_values': '1'}, np.float64),
+    ({'data_type': 'nominal', 'number_of_missing_values': '0'}, 'category'),
+    ({'data_type': 'nominal', 'number_of_missing_values': '1'}, 'category'),
+])
+def test_feature_to_dtype(feature, expected_dtype):
+    assert _feature_to_dtype(feature) == expected_dtype
+
+
+@pytest.mark.parametrize('feature', [
+    {'data_type': 'datetime', 'number_of_missing_values': '0'}
+])
+def test_feature_to_dtype_error(feature):
+    msg = 'Unsupported feature: {}'.format(feature)
+    with pytest.raises(ValueError, match=msg):
+        _feature_to_dtype(feature)
+
+
+def test_fetch_openml_iris_pandas(monkeypatch):
+    # classification dataset with numeric only columns
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+    data_id = 61
+    data_shape = (150, 4)
+    target_shape = (150, )
+    frame_shape = (150, 5)
+
+    target_dtype = CategoricalDtype(['Iris-setosa', 'Iris-versicolor',
+                                     'Iris-virginica'])
+    data_dtypes = [np.float64] * 4
+    data_names = ['sepallength', 'sepalwidth', 'petallength', 'petalwidth']
+    target_name = 'class'
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert np.all(data.dtypes == data_dtypes)
+    assert data.shape == data_shape
+    assert np.all(data.columns == data_names)
+    assert np.all(bunch.feature_names == data_names)
+
+    assert isinstance(target, pd.Series)
+    assert target.dtype == target_dtype
+    assert target.shape == target_shape
+    assert target.name == target_name
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+    assert np.all(frame.dtypes == data_dtypes + [target_dtype])
+
+
+def test_fetch_openml_iris_pandas_equal_to_no_frame(monkeypatch):
+    # as_frame = True returns the same underlying data as as_frame = False
+    pytest.importorskip('pandas')
+    data_id = 61
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    frame_bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    frame_data = frame_bunch.data
+    frame_target = frame_bunch.target
+
+    norm_bunch = fetch_openml(data_id=data_id, as_frame=False, cache=False)
+    norm_data = norm_bunch.data
+    norm_target = norm_bunch.target
+
+    assert_allclose(norm_data, frame_data)
+    assert_array_equal(norm_target, frame_target)
+
+
+def test_fetch_openml_iris_multitarget_pandas(monkeypatch):
+    # iris dataset fetched with two numeric target columns
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+    data_id = 61
+    data_shape = (150, 3)
+    target_shape = (150, 2)
+    frame_shape = (150, 5)
+    target_column = ['petalwidth', 'petallength']
+
+    cat_dtype = CategoricalDtype(['Iris-setosa', 'Iris-versicolor',
+                                  'Iris-virginica'])
+    data_dtypes = [np.float64, np.float64] + [cat_dtype]
+    data_names = ['sepallength', 'sepalwidth', 'class']
+    target_dtypes = [np.float64, np.float64]
+    target_names = ['petalwidth', 'petallength']
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False,
+                         target_column=target_column)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert np.all(data.dtypes == data_dtypes)
+    assert data.shape == data_shape
+    assert np.all(data.columns == data_names)
+    assert np.all(bunch.feature_names == data_names)
+
+    assert isinstance(target, pd.DataFrame)
+    assert np.all(target.dtypes == target_dtypes)
+    assert target.shape == target_shape
+    assert np.all(target.columns == target_names)
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+    assert np.all(frame.dtypes == [np.float64] * 4 + [cat_dtype])
+
+
+def test_fetch_openml_anneal_pandas(monkeypatch):
+    # classification dataset with numeric and categorical columns
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    data_id = 2
+    target_column = 'class'
+    data_shape = (11, 38)
+    target_shape = (11,)
+    frame_shape = (11, 39)
+    expected_data_categories = 32
+    expected_data_floats = 6
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    bunch = fetch_openml(data_id=data_id, as_frame=True,
+                         target_column=target_column, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+    n_categories = len([dtype for dtype in data.dtypes
+                        if isinstance(dtype, CategoricalDtype)])
+    n_floats = len([dtype for dtype in data.dtypes if dtype.kind == 'f'])
+    assert expected_data_categories == n_categories
+    assert expected_data_floats == n_floats
+
+    assert isinstance(target, pd.Series)
+    assert target.shape == target_shape
+    assert isinstance(target.dtype, CategoricalDtype)
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+
+
+def test_fetch_openml_cpu_pandas(monkeypatch):
+    # regression dataset with numeric and categorical columns
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+    data_id = 561
+    data_shape = (209, 7)
+    target_shape = (209, )
+    frame_shape = (209, 8)
+
+    cat_dtype = CategoricalDtype(['adviser', 'amdahl', 'apollo', 'basf',
+                                  'bti', 'burroughs', 'c.r.d', 'cdc',
+                                  'cambex', 'dec', 'dg', 'formation',
+                                  'four-phase', 'gould', 'hp', 'harris',
+                                  'honeywell', 'ibm', 'ipl', 'magnuson',
+                                  'microdata', 'nas', 'ncr', 'nixdorf',
+                                  'perkin-elmer', 'prime', 'siemens',
+                                  'sperry', 'sratus', 'wang'])
+    data_dtypes = [cat_dtype] + [np.float64] * 6
+    feature_names = ['vendor', 'MYCT', 'MMIN', 'MMAX', 'CACH',
+                     'CHMIN', 'CHMAX']
+    target_name = 'class'
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+    assert np.all(data.dtypes == data_dtypes)
+    assert np.all(data.columns == feature_names)
+    assert np.all(bunch.feature_names == feature_names)
+
+    assert isinstance(target, pd.Series)
+    assert target.shape == target_shape
+    assert target.dtype == np.float64
+    assert target.name == target_name
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+
+
+def test_fetch_openml_australian_pandas_error_sparse(monkeypatch):
+    data_id = 292
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    msg = 'Cannot return dataframe with sparse data'
+    with pytest.raises(ValueError, match=msg):
+        fetch_openml(data_id=data_id, as_frame=True, cache=False)
+
+
+def test_convert_arff_data_dataframe_warning_low_memory_pandas(monkeypatch):
+    pytest.importorskip('pandas')
+
+    data_id = 1119
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+
+    msg = 'Could not adhere to working_memory config.'
+    with pytest.warns(UserWarning, match=msg):
+        with config_context(working_memory=1e-6):
+            fetch_openml(data_id=data_id, as_frame=True, cache=False)
+
+
+def test_fetch_openml_adultcensus_pandas_return_X_y(monkeypatch):
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    data_id = 1119
+    data_shape = (10, 14)
+    target_shape = (10, )
+
+    expected_data_categories = 8
+    expected_data_floats = 6
+    target_column = 'class'
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    X, y = fetch_openml(data_id=data_id, as_frame=True, cache=False,
+                        return_X_y=True)
+    assert isinstance(X, pd.DataFrame)
+    assert X.shape == data_shape
+    n_categories = len([dtype for dtype in X.dtypes
+                        if isinstance(dtype, CategoricalDtype)])
+    n_floats = len([dtype for dtype in X.dtypes if dtype.kind == 'f'])
+    assert expected_data_categories == n_categories
+    assert expected_data_floats == n_floats
+
+    assert isinstance(y, pd.Series)
+    assert y.shape == target_shape
+    assert y.name == target_column
+
+
+def test_fetch_openml_adultcensus_pandas(monkeypatch):
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    # Check because of the numeric row attribute (issue #12329)
+    data_id = 1119
+    data_shape = (10, 14)
+    target_shape = (10, )
+    frame_shape = (10, 15)
+
+    expected_data_categories = 8
+    expected_data_floats = 6
+    target_column = 'class'
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+    n_categories = len([dtype for dtype in data.dtypes
+                        if isinstance(dtype, CategoricalDtype)])
+    n_floats = len([dtype for dtype in data.dtypes if dtype.kind == 'f'])
+    assert expected_data_categories == n_categories
+    assert expected_data_floats == n_floats
+
+    assert isinstance(target, pd.Series)
+    assert target.shape == target_shape
+    assert target.name == target_column
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+
+
+def test_fetch_openml_miceprotein_pandas(monkeypatch):
+    # JvR: very important check, as this dataset defines several row ids
+    # and ignore attributes. Note that the data_features json has 82
+    # attributes; the row id (1) and ignore attributes (3) were removed.
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    data_id = 40966
+    data_shape = (7, 77)
+    target_shape = (7, )
+    frame_shape = (7, 78)
+
+    target_column = 'class'
+    frame_n_categories = 1
+    frame_n_floats = 77
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+    assert np.all(data.dtypes == np.float64)
+
+    assert isinstance(target, pd.Series)
+    assert isinstance(target.dtype, CategoricalDtype)
+    assert target.shape == target_shape
+    assert target.name == target_column
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+    n_categories = len([dtype for dtype in frame.dtypes
+                        if isinstance(dtype, CategoricalDtype)])
+    n_floats = len([dtype for dtype in frame.dtypes if dtype.kind == 'f'])
+    assert frame_n_categories == n_categories
+    assert frame_n_floats == n_floats
+
+
+def test_fetch_openml_emotions_pandas(monkeypatch):
+    # classification dataset with multiple targets (natively)
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    data_id = 40589
+    target_column = ['amazed.suprised', 'happy.pleased', 'relaxing.calm',
+                     'quiet.still', 'sad.lonely', 'angry.aggresive']
+    data_shape = (13, 72)
+    target_shape = (13, 6)
+    frame_shape = (13, 78)
+
+    expected_frame_categories = 6
+    expected_frame_floats = 72
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False,
+                         target_column=target_column)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+
+    assert isinstance(target, pd.DataFrame)
+    assert target.shape == target_shape
+    assert np.all(target.columns == target_column)
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+    n_categories = len([dtype for dtype in frame.dtypes
+                        if isinstance(dtype, CategoricalDtype)])
+    n_floats = len([dtype for dtype in frame.dtypes if dtype.kind == 'f'])
+    assert expected_frame_categories == n_categories
+    assert expected_frame_floats == n_floats
+
+
+def test_fetch_openml_titanic_pandas(monkeypatch):
+    # dataset with string columns
+    pd = pytest.importorskip('pandas')
+    CategoricalDtype = pd.api.types.CategoricalDtype
+
+    data_id = 40945
+    data_shape = (1309, 13)
+    target_shape = (1309, )
+    frame_shape = (1309, 14)
+    name_to_dtype = {
+        'pclass': np.float64,
+        'name': object,
+        'sex': CategoricalDtype(['female', 'male']),
+        'age': np.float64,
+        'sibsp': np.float64,
+        'parch': np.float64,
+        'ticket': object,
+        'fare': np.float64,
+        'cabin': object,
+        'embarked': CategoricalDtype(['C', 'Q', 'S']),
+        'boat': object,
+        'body': np.float64,
+        'home.dest': object,
+        'survived': CategoricalDtype(['0', '1'])
+    }
+
+    frame_columns = ['pclass', 'survived', 'name', 'sex', 'age', 'sibsp',
+                     'parch', 'ticket', 'fare', 'cabin', 'embarked',
+                     'boat', 'body', 'home.dest']
+    frame_dtypes = [name_to_dtype[col] for col in frame_columns]
+    feature_names = ['pclass', 'name',
+                     'sex', 'age', 'sibsp',
+                     'parch', 'ticket', 'fare', 'cabin', 'embarked',
+                     'boat', 'body', 'home.dest']
+    target_name = 'survived'
+
+    _monkey_patch_webbased_functions(monkeypatch, data_id, True)
+    bunch = fetch_openml(data_id=data_id, as_frame=True, cache=False)
+    data = bunch.data
+    target = bunch.target
+    frame = bunch.frame
+
+    assert isinstance(data, pd.DataFrame)
+    assert data.shape == data_shape
+    assert np.all(data.columns == feature_names)
+
+    assert isinstance(target, pd.Series)
+    assert target.shape == target_shape
+    assert target.name == target_name
+    assert target.dtype == name_to_dtype[target_name]
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == frame_shape
+    assert np.all(frame.dtypes == frame_dtypes)
+
+
 @pytest.mark.parametrize('gzip_response', [True, False])
 def test_fetch_openml_iris(monkeypatch, gzip_response):
     # classification dataset with numeric only columns
@@ -661,12 +1086,13 @@ def test_warn_ignore_attribute(monkeypatch, gzip_response):
 
 
 @pytest.mark.parametrize('gzip_response', [True, False])
-def test_string_attribute(monkeypatch, gzip_response):
+def test_string_attribute_without_dataframe(monkeypatch, gzip_response):
     data_id = 40945
     _monkey_patch_webbased_functions(monkeypatch, data_id, gzip_response)
     # single column test
     assert_raise_message(ValueError,
-                         'STRING attributes are not yet supported',
+                         ('STRING attributes are not supported for '
+                          'array representation. Try as_frame=True'),
                          fetch_openml,
                          data_id=data_id, cache=False)
 
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index daf7e7763235d..4528c2ba0caeb 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -3,6 +3,7 @@
 """
 from collections.abc import Sequence
 from contextlib import contextmanager
+from itertools import islice
 import numbers
 import platform
 import struct
@@ -477,6 +478,17 @@ def safe_sqr(X, copy=True):
         return X
 
 
+def _chunk_generator(gen, chunksize):
+    """Chunk generator ``gen`` into lists of length ``chunksize``. The last
+    chunk may have a length less than ``chunksize``."""
+    while True:
+        chunk = list(islice(gen, chunksize))
+        if chunk:
+            yield chunk
+        else:
+            return
+
+
 def gen_batches(n, batch_size, min_batch_size=0):
     """Generator to create slices containing batch_size elements, from 0 to n.
 
@@ -824,3 +836,24 @@ def check_matplotlib_support(caller_name):
             "{} requires matplotlib. You can install matplotlib with "
             "`pip install matplotlib`".format(caller_name)
         ) from e
+
+
+def check_pandas_support(caller_name):
+    """Raise ImportError with detailed error message if pandas is not
+    installed.
+
+    Functions like :func:`fetch_openml` should lazily import pandas
+    and call this helper before any computation.
+
+    Parameters
+    ----------
+    caller_name : str
+        The name of the caller that requires pandas.
+    """
+    try:
+        import pandas  # noqa
+        return pandas
+    except ImportError as e:
+        raise ImportError(
+            "{} requires pandas.".format(caller_name)
+        ) from e
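
A minimal usage sketch of the new `as_frame=True` path, for reviewers trying
the branch locally (assumes pandas is installed and the OpenML server is
reachable; `data_id=61` is the iris dataset used in the tests above):

    from sklearn.datasets import fetch_openml

    # data/target/frame are pandas objects instead of numpy arrays
    bunch = fetch_openml(data_id=61, as_frame=True)
    print(bunch.data.shape)    # (150, 4) DataFrame of float64 columns
    print(bunch.target.dtype)  # category
    print(bunch.frame.shape)   # (150, 5): data columns plus the target
    print(bunch.categories)    # None when as_frame=True

    # with return_X_y=True the pandas objects are returned directly
    X, y = fetch_openml(data_id=61, as_frame=True, return_X_y=True)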
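
The dtype mapping performed by `_feature_to_dtype` can be exercised in
isolation; the expected values below simply restate rows of the parametrized
test:

    from sklearn.datasets.openml import _feature_to_dtype

    # integer columns are widened to float64 as soon as missing values are
    # reported, because a numpy/pandas integer column cannot hold NaN
    _feature_to_dtype({'data_type': 'integer',
                       'number_of_missing_values': '0'})  # -> np.int64
    _feature_to_dtype({'data_type': 'integer',
                       'number_of_missing_values': '1'})  # -> np.float64
    _feature_to_dtype({'data_type': 'nominal',
                       'number_of_missing_values': '1'})  # -> 'category'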
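
`_chunk_generator` is what lets `_convert_arff_data_dataframe` materialize the
ARFF row generator in working_memory-sized pieces; a sketch of its behavior
(note it must be fed an iterator, since each chunk is drawn with `islice`):

    from sklearn.utils import _chunk_generator

    chunks = list(_chunk_generator(iter(range(5)), 2))
    # [[0, 1], [2, 3], [4]] -- the last chunk may be shorter than chunksize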
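
`check_pandas_support` mirrors the existing `check_matplotlib_support`
pattern: the caller binds the returned module locally instead of importing
pandas at module level. A sketch with a hypothetical caller name:

    from sklearn.utils import check_pandas_support

    def frame_from_records(records):
        # raises "frame_from_records requires pandas." when pandas is missing
        pd = check_pandas_support('frame_from_records')
        return pd.DataFrame(records)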