ENH adding as_frame functionality for CA housing dataset loader (#15950) · panpiort8/scikit-learn@e6fb20b · GitHub
[go: up one dir, main page]

Skip to content

Commit e6fb20b

Browse files
Reshama Shaikh and gitsteph
authored and
Pan Jan
committed
ENH adding as_frame functionality for CA housing dataset loader (scikit-learn#15950)
Co-authored-by: Stephanie Andrews <byronic.string@gmail.com>
1 parent 9aa27d3 commit e6fb20b

File tree

4 files changed

+96
-8
lines changed

4 files changed

+96
-8
lines changed

doc/whats_new/v0.23.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ Changelog
5858
:func:`datasets.make_moons` now accept two-element tuple.
5959
:pr:`15707` by :user:`Maciej J Mikulski <mjmikulski>`.
6060

61+
- |Feature| :func:`datasets.fetch_california_housing` now supports
62+
heterogeneous data using pandas by setting `as_frame=True`. :pr:`15950`
63+
by :user:`Stephanie Andrews <gitsteph>` and
64+
:user:`Reshama Shaikh <reshamas>`.
65+
6166
:mod:`sklearn.feature_extraction`
6267
.................................
6368

sklearn/datasets/_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..utils import Bunch
1919
from ..utils import check_random_state
20+
from ..utils import check_pandas_support
2021

2122
import numpy as np
2223

@@ -67,6 +68,17 @@ def clear_data_home(data_home=None):
6768
shutil.rmtree(data_home)
6869

6970

71+
def _convert_data_dataframe(caller_name, data, target,
                            feature_names, target_names):
    """Build the pandas containers a loader returns for ``as_frame=True``.

    Parameters
    ----------
    caller_name : str
        Name of the calling loader. It is interpolated into the message
        handed to ``check_pandas_support``.
    data : array-like
        Feature values; wrapped in a DataFrame whose columns are
        ``feature_names``.
    target : array-like
        Target values; wrapped in a DataFrame whose columns are
        ``target_names``.
    feature_names : list-like
        Column labels for ``data``.
    target_names : list-like
        Column labels for ``target``.

    Returns
    -------
    frame : pandas DataFrame
        Feature and target columns concatenated side by side.
    X : pandas object
        The ``feature_names`` columns selected from ``frame``.
    y : pandas object
        The ``target_names`` columns selected from ``frame``.
    """
    # The value returned by check_pandas_support is used as the pandas
    # module below. NOTE(review): presumably it raises ImportError when
    # pandas is missing -- confirm against its definition in ..utils.
    purpose = '{} with as_frame=True'.format(caller_name)
    pd = check_pandas_support(purpose)

    features = pd.DataFrame(data, columns=feature_names)
    targets = pd.DataFrame(target, columns=target_names)
    frame = pd.concat([features, targets], axis=1)

    # X and y are selected from the combined frame rather than reusing the
    # intermediate frames, exactly as the original implementation did.
    return frame, frame[feature_names], frame[target_names]
80+
81+
7082
def load_files(container_path, description=None, categories=None,
7183
load_content=True, shuffle=True, encoding=None,
7284
decode_error='strict', random_state=0):

sklearn/datasets/_california_housing.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import joblib
3232

3333
from . import get_data_home
34+
from ._base import _convert_data_dataframe
3435
from ._base import _fetch_remote
3536
from ._base import _pkl_filepath
3637
from ._base import RemoteFileMetadata
@@ -49,7 +50,7 @@
4950

5051

5152
def fetch_california_housing(data_home=None, download_if_missing=True,
52-
return_X_y=False):
53+
return_X_y=False, as_frame=False):
5354
"""Load the California housing dataset (regression).
5455
5556
============== ==============
@@ -78,15 +79,24 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
7879
7980
.. versionadded:: 0.20
8081
82+
as_frame : boolean, default=False
83+
If True, the data is a pandas DataFrame including columns with
84+
appropriate dtypes (numeric, string or categorical). The target is
85+
a pandas DataFrame or Series depending on the number of target_columns.
86+
87+
.. versionadded:: 0.23
88+
8189
Returns
8290
-------
8391
dataset : dict-like object with the following attributes:
8492
8593
dataset.data : ndarray, shape [20640, 8]
8694
Each row corresponding to the 8 feature values in order.
95+
If ``as_frame`` is True, ``data`` is a pandas object.
8796
8897
dataset.target : numpy array of shape (20640,)
8998
Each value corresponds to the average house value in units of 100,000.
99+
If ``as_frame`` is True, ``target`` is a pandas object.
90100
91101
dataset.feature_names : array of length 8
92102
Array of ordered feature names used in the dataset.
@@ -98,6 +108,12 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
98108
99109
.. versionadded:: 0.20
100110
111+
frame : pandas DataFrame
112+
Only present when `as_frame=True`. DataFrame with ``data`` and
113+
``target``.
114+
115+
.. versionadded:: 0.23
116+
101117
Notes
102118
-----
103119
@@ -155,10 +171,24 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
155171
with open(join(module_path, 'descr', 'california_housing.rst')) as dfile:
156172
descr = dfile.read()
157173

174+
X = data
175+
y = target
176+
177+
frame = None
178+
target_names = ["MedHouseVal", ]
179+
if as_frame:
180+
frame, X, y = _convert_data_dataframe("fetch_california_housing",
181+
data,
182+
target,
183+
feature_names,
184+
target_names)
185+
158186
if return_X_y:
159-
return data, target
187+
return X, y
160188

161-
return Bunch(data=data,
162-
target=target,
189+
return Bunch(data=X,
190+
target=y,
191+
frame=frame,
192+
target_names=target_names,
163193
feature_names=feature_names,
164194
DESCR=descr)

sklearn/datasets/tests/test_california_housing.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
Skipped if california_housing is not already downloaded to data_home.
44
"""
55

6+
import pytest
7+
68
from sklearn.datasets import fetch_california_housing
7-
from sklearn.utils._testing import SkipTest
89
from sklearn.datasets.tests.test_common import check_return_X_y
910
from functools import partial
1011

@@ -13,14 +14,54 @@ def fetch(*args, **kwargs):
1314
return fetch_california_housing(*args, download_if_missing=False, **kwargs)
1415

1516

16-
def test_fetch():
17+
def _is_california_housing_dataset_not_available():
1718
try:
18-
data = fetch()
19+
fetch_california_housing(download_if_missing=False)
20+
return False
1921
except IOError:
20-
raise SkipTest("California housing dataset can not be loaded.")
22+
return True
23+
24+
25+
@pytest.mark.skipif(
    _is_california_housing_dataset_not_available(),
    reason='Download California Housing dataset to run this test'
)
def test_fetch():
    """Check array shapes and the return_X_y option on the cached dataset."""
    bunch = fetch()
    assert bunch.data.shape == (20640, 8)
    assert bunch.target.shape == (20640, )

    # test return_X_y option
    check_return_X_y(bunch, partial(fetch))
37+
38+
39+
@pytest.mark.skipif(
    _is_california_housing_dataset_not_available(),
    reason='Download California Housing dataset to run this test'
)
def test_fetch_asframe():
    """Check the pandas containers returned when ``as_frame=True``."""
    pd = pytest.importorskip('pandas')
    bunch = fetch(as_frame=True)

    # Assert the attribute exists *before* reading it; reading first would
    # raise AttributeError and make the hasattr check unreachable.
    assert hasattr(bunch, 'frame')
    frame = bunch.frame

    # The frame holds the 8 feature columns plus the single target column.
    assert isinstance(frame, pd.DataFrame)
    assert frame.shape == (20640, 9)

    assert isinstance(bunch.data, pd.DataFrame)
    assert bunch.data.shape == (20640, 8)

    assert isinstance(bunch.target, pd.DataFrame)
    assert bunch.target.shape == (20640, 1)
51+
52+
53+
@pytest.mark.skipif(
    _is_california_housing_dataset_not_available(),
    reason='Download California Housing dataset to run this test'
)
def test_pandas_dependency_message():
    """Check the error raised for ``as_frame=True`` when pandas is absent."""
    try:
        import pandas  # noqa
    except ImportError:
        # pandas is missing, which is the situation this test needs.
        pass
    else:
        pytest.skip("This test requires pandas to be not installed")

    # Check that pandas is imported lazily and that an informative error
    # message is raised when pandas is missing:
    expected_msg = ('fetch_california_housing with as_frame=True'
                    ' requires pandas')
    with pytest.raises(ImportError, match=expected_msg):
        fetch_california_housing(as_frame=True)

0 commit comments

Comments
 (0)
0