diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py
index 2305f60ebbb54..d37b308a6d5d8 100644
--- a/sklearn/datasets/base.py
+++ b/sklearn/datasets/base.py
@@ -322,7 +322,7 @@ def load_wine(return_X_y=False):
                                 'proline'])
 
 
-def load_iris(return_X_y=False):
+def load_iris(return_X_y=False, as_frame=False):
     """Load and return the iris dataset (classification).
 
     The iris dataset is a classic and very easy multi-class classification
@@ -382,11 +382,20 @@ def load_iris(return_X_y=False):
     if return_X_y:
         return data, target
 
+    feature_names = ['sepal length (cm)', 'sepal width (cm)',
+                     'petal length (cm)', 'petal width (cm)']
+
+    if as_frame:
+        # pandas is an optional dependency: import it only when requested.
+        from pandas import Series, DataFrame
+        data_frame = DataFrame(data, columns=feature_names)
+        target_series = Series(target, name="class")
+        return data_frame, target_series
+
     return Bunch(data=data, target=target,
                  target_names=target_names,
                  DESCR=fdescr,
-                 feature_names=['sepal length (cm)', 'sepal width (cm)',
-                                'petal length (cm)', 'petal width (cm)'],
+                 feature_names=feature_names,
                  filename=iris_csv_filename)
 
 
diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py
index 091735b986a3f..61c495c01f1de 100644
--- a/sklearn/datasets/tests/test_base.py
+++ b/sklearn/datasets/tests/test_base.py
@@ -29,6 +29,7 @@
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import SkipTest
 
 
 DATA_HOME = tempfile.mkdtemp(prefix="scikit_learn_data_home_test_")
@@ -202,6 +203,19 @@ def test_load_iris():
     check_return_X_y(res, partial(load_iris))
 
 
+def test_load_iris_as_frame():
+    # as_frame=True needs pandas, an optional dependency: skip (rather
+    # than fail) when it is not installed.
+    try:
+        import pandas  # noqa: F401
+    except ImportError:
+        raise SkipTest("Pandas is needed to run the test")
+
+    data_frame, target_series = load_iris(as_frame=True)
+    assert_equal(data_frame.shape, (150, 4))
+    assert_equal(target_series.shape[0], 150)
+
+
 def test_load_wine():
     res = load_wine()
     assert_equal(res.data.shape, (178, 13))