diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py index 77dac99c1d970..b9436b232b7ed 100644 --- a/sklearn/datasets/__init__.py +++ b/sklearn/datasets/__init__.py @@ -3,6 +3,14 @@ including methods to load and fetch popular reference datasets. It also features some artificial data generators. """ +from .base import Boston +from .base import BreastCancer +from .base import Diabetes +from .base import Digits +from .base import Iris +from .base import Linnerud +from .base import Wine +from .base import SampleImages from .base import load_breast_cancer from .base import load_boston from .base import load_diabetes @@ -52,7 +60,11 @@ from .rcv1 import fetch_rcv1 -__all__ = ['clear_data_home', +__all__ = ['Boston', + 'BreastCancer', + 'clear_data_home', + 'Diabetes', + 'Digits', 'dump_svmlight_file', 'fetch_20newsgroups', 'fetch_20newsgroups_vectorized', @@ -67,6 +79,8 @@ 'fetch_kddcup99', 'fetch_openml', 'get_data_home', + 'Iris', + 'Linnerud', 'load_boston', 'load_diabetes', 'load_digits', @@ -99,4 +113,6 @@ 'make_sparse_uncorrelated', 'make_spd_matrix', 'make_swiss_roll', - 'mldata_filename'] + 'mldata_filename', + 'SampleImages', + 'Wine'] diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index 34214bf3f58de..5448cfae8ee20 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -6,8 +6,6 @@ # 2010 Fabian Pedregosa # 2010 Olivier Grisel # License: BSD 3 clause -import os -import csv import sys import shutil from collections import namedtuple @@ -19,11 +17,17 @@ from ..utils import check_random_state import numpy as np +from numpy import ndarray from urllib.request import urlretrieve +from typing import Dict, Tuple, Union, Any +from abc import ABCMeta, abstractmethod +from ..utils import deprecated + RemoteFileMetadata = namedtuple('RemoteFileMetadata', ['filename', 'url', 'checksum']) +Dataset = Union[Bunch, Tuple[Any, Any]] def get_data_home(data_home=None): @@ -204,17 +208,16 @@ def load_files(container_path, 
description=None, categories=None, DESCR=description) -def load_data(module_path, data_file_name): +@deprecated("'load_data' was renamed to " + "'SimpleCSVLocalDatasetLoader.read_X_y_csv' " + "in version 0.21 and will be removed in 0.23.") +def load_data(path: str) -> Tuple[ndarray, ndarray, ndarray]: """Loads data from module_path/data/data_file_name. Parameters ---------- - module_path : string - The module path. - - data_file_name : string - Name of csv file to be loaded from - module_path/data/data_file_name. For example 'wine_data.csv'. + path : string + The data file path. Returns ------- @@ -230,22 +233,18 @@ def load_data(module_path, data_file_name): A 1D array containing the names of the classifications. For example target_names[0] is the name of the target[0] class. """ - with open(join(module_path, 'data', data_file_name)) as csv_file: - data_file = csv.reader(csv_file) - temp = next(data_file) - n_samples = int(temp[0]) - n_features = int(temp[1]) - target_names = np.array(temp[2:]) - data = np.empty((n_samples, n_features)) - target = np.empty((n_samples,), dtype=np.int) + return SimpleCSVLocalDatasetLoader.read_X_y_csv(path=path) - for i, ir in enumerate(data_file): - data[i] = np.asarray(ir[:-1], dtype=np.float64) - target[i] = np.asarray(ir[-1], dtype=np.int) - return data, target, target_names +def _attempt_cast_to_int(arr: ndarray) -> ndarray: + arri = arr.astype('int', casting='unsafe') + if (arr == arri).all(): + return arri + return arr +@deprecated("'load_wine' was renamed to 'Wine().load' " + "in version 0.21 and will be removed in 0.23.") def load_wine(return_X_y=False): """Load and return the wine dataset (classification). @@ -289,40 +288,18 @@ def load_wine(return_X_y=False): Let's say you are interested in the samples 10, 80, and 140, and want to know their class name.
- >>> from sklearn.datasets import load_wine - >>> data = load_wine() + >>> from sklearn.datasets import Wine + >>> data = Wine().load() >>> data.target[[10, 80, 140]] array([0, 1, 2]) >>> list(data.target_names) ['class_0', 'class_1', 'class_2'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'wine_data.csv') - - with open(join(module_path, 'descr', 'wine_data.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=['alcohol', - 'malic_acid', - 'ash', - 'alcalinity_of_ash', - 'magnesium', - 'total_phenols', - 'flavanoids', - 'nonflavanoid_phenols', - 'proanthocyanins', - 'color_intensity', - 'hue', - 'od280/od315_of_diluted_wines', - 'proline']) + return Wine().load(return_X_y=return_X_y) +@deprecated("'load_iris' was renamed to 'Iris().load' " + "in version 0.21 and will be removed in 0.23.") def load_iris(return_X_y=False): """Load and return the iris dataset (classification). @@ -373,31 +350,18 @@ def load_iris(return_X_y=False): Let's say you are interested in the samples 10, 25, and 50, and want to know their class name.
- >>> from sklearn.datasets import load_iris - >>> data = load_iris() + >>> from sklearn.datasets import Iris + >>> data = Iris().load() >>> data.target[[10, 25, 50]] array([0, 0, 1]) >>> list(data.target_names) ['setosa', 'versicolor', 'virginica'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'iris.csv') - iris_csv_filename = join(module_path, 'data', 'iris.csv') - - with open(join(module_path, 'descr', 'iris.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=['sepal length (cm)', 'sepal width (cm)', - 'petal length (cm)', 'petal width (cm)'], - filename=iris_csv_filename) + return Iris().load(return_X_y=return_X_y) +@deprecated("'load_breast_cancer' was renamed to 'BreastCancer().load' " + "in version 0.21 and will be removed in 0.23.") def load_breast_cancer(return_X_y=False): """Load and return the breast cancer wisconsin dataset (classification). @@ -445,46 +409,18 @@ def load_breast_cancer(return_X_y=False): Let's say you are interested in the samples 10, 50, and 85, and want to know their class name.
- >>> from sklearn.datasets import load_breast_cancer - >>> data = load_breast_cancer() + >>> from sklearn.datasets import BreastCancer + >>> data = BreastCancer().load() >>> data.target[[10, 50, 85]] array([0, 1, 0]) >>> list(data.target_names) ['malignant', 'benign'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'breast_cancer.csv') - csv_filename = join(module_path, 'data', 'breast_cancer.csv') - - with open(join(module_path, 'descr', 'breast_cancer.rst')) as rst_file: - fdescr = rst_file.read() - - feature_names = np.array(['mean radius', 'mean texture', - 'mean perimeter', 'mean area', - 'mean smoothness', 'mean compactness', - 'mean concavity', 'mean concave points', - 'mean symmetry', 'mean fractal dimension', - 'radius error', 'texture error', - 'perimeter error', 'area error', - 'smoothness error', 'compactness error', - 'concavity error', 'concave points error', - 'symmetry error', 'fractal dimension error', - 'worst radius', 'worst texture', - 'worst perimeter', 'worst area', - 'worst smoothness', 'worst compactness', - 'worst concavity', 'worst concave points', - 'worst symmetry', 'worst fractal dimension']) - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=feature_names, - filename=csv_filename) + return BreastCancer().load(return_X_y=return_X_y) +@deprecated("'load_digits' was renamed to 'Digits().load' " + "in version 0.21 and will be removed in 0.23.") def load_digits(n_class=10, return_X_y=False): """Load and return the digits dataset (classification).
@@ -531,8 +467,8 @@ def load_digits(n_class=10, return_X_y=False): -------- To load the data and visualize the images:: - >>> from sklearn.datasets import load_digits - >>> digits = load_digits() + >>> from sklearn.datasets import Digits + >>> digits = Digits().load() >>> print(digits.data.shape) (1797, 64) >>> import matplotlib.pyplot as plt #doctest: +SKIP @@ -540,31 +476,11 @@ def load_digits(n_class=10, return_X_y=False): >>> plt.matshow(digits.images[0]) #doctest: +SKIP >>> plt.show() #doctest: +SKIP """ - module_path = dirname(__file__) - data = np.loadtxt(join(module_path, 'data', 'digits.csv.gz'), - delimiter=',') - with open(join(module_path, 'descr', 'digits.rst')) as f: - descr = f.read() - target = data[:, -1].astype(np.int) - flat_data = data[:, :-1] - images = flat_data.view() - images.shape = (-1, 8, 8) - - if n_class < 10: - idx = target < n_class - flat_data, target = flat_data[idx], target[idx] - images = images[idx] - - if return_X_y: - return flat_data, target - - return Bunch(data=flat_data, - target=target, - target_names=np.arange(10), - images=images, - DESCR=descr) + return Digits(n_class=n_class).load(return_X_y=return_X_y) +@deprecated("'load_diabetes' was renamed to 'Diabetes().load' " + "in version 0.21 and will be removed in 0.23.") def load_diabetes(return_X_y=False): """Load and return the diabetes dataset (regression). @@ -598,26 +514,11 @@ def load_diabetes(return_X_y=False): ..
versionadded:: 0.18 """ - module_path = dirname(__file__) - base_dir = join(module_path, 'data') - data_filename = join(base_dir, 'diabetes_data.csv.gz') - data = np.loadtxt(data_filename) - target_filename = join(base_dir, 'diabetes_target.csv.gz') - target = np.loadtxt(target_filename) - - with open(join(module_path, 'descr', 'diabetes.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, DESCR=fdescr, - feature_names=['age', 'sex', 'bmi', 'bp', - 's1', 's2', 's3', 's4', 's5', 's6'], - data_filename=data_filename, - target_filename=target_filename) + return Diabetes().load(return_X_y=return_X_y) +@deprecated("'load_linnerud' was renamed to 'Linnerud().load' " + "in version 0.21 and will be removed in 0.23.") def load_linnerud(return_X_y=False): """Load and return the linnerud dataset (multivariate regression). @@ -654,34 +555,11 @@ def load_linnerud(return_X_y=False): .. versionadded:: 0.18 """ - base_dir = join(dirname(__file__), 'data/') - data_filename = join(base_dir, 'linnerud_exercise.csv') - target_filename = join(base_dir, 'linnerud_physiological.csv') - - # Read data - data_exercise = np.loadtxt(data_filename, skiprows=1) - data_physiological = np.loadtxt(target_filename, skiprows=1) - - # Read header - with open(data_filename) as f: - header_exercise = f.readline().split() - with open(target_filename) as f: - header_physiological = f.readline().split() - - with open(dirname(__file__) + '/descr/linnerud.rst') as f: - descr = f.read() - - if return_X_y: - return data_exercise, data_physiological - - return Bunch(data=data_exercise, feature_names=header_exercise, - target=data_physiological, - target_names=header_physiological, - DESCR=descr, - data_filename=data_filename, - target_filename=target_filename) + return Linnerud().load(return_X_y=return_X_y) +@deprecated("'load_boston' was renamed to 'Boston().load' " + "in version 0.21 and will be removed in 0.23.") def
load_boston(return_X_y=False): """Load and return the boston house-prices dataset (regression). @@ -722,43 +600,16 @@ def load_boston(return_X_y=False): Examples -------- - >>> from sklearn.datasets import load_boston - >>> boston = load_boston() + >>> from sklearn.datasets import Boston + >>> boston = Boston().load() >>> print(boston.data.shape) (506, 13) """ - module_path = dirname(__file__) - - fdescr_name = join(module_path, 'descr', 'boston_house_prices.rst') - with open(fdescr_name) as f: - descr_text = f.read() - - data_file_name = join(module_path, 'data', 'boston_house_prices.csv') - with open(data_file_name) as f: - data_file = csv.reader(f) - temp = next(data_file) - n_samples = int(temp[0]) - n_features = int(temp[1]) - data = np.empty((n_samples, n_features)) - target = np.empty((n_samples,)) - temp = next(data_file) # names of features - feature_names = np.array(temp) - - for i, d in enumerate(data_file): - data[i] = np.asarray(d[:-1], dtype=np.float64) - target[i] = np.asarray(d[-1], dtype=np.float64) - - if return_X_y: - return data, target - - return Bunch(data=data, - target=target, - # last column is target value - feature_names=feature_names[:-1], - DESCR=descr_text, - filename=data_file_name) + return Boston().load(return_X_y=return_X_y) +@deprecated("'load_sample_images' was renamed to 'SampleImages().load' " + "in version 0.21 and will be removed in 0.23.") def load_sample_images(): """Load sample images for image manipulation.
@@ -777,8 +628,8 @@ def load_sample_images(): -------- To load the data and visualize the images: - >>> from sklearn.datasets import load_sample_images - >>> dataset = load_sample_images() #doctest: +SKIP + >>> from sklearn.datasets import SampleImages + >>> dataset = SampleImages().load() #doctest: +SKIP >>> len(dataset.images) #doctest: +SKIP 2 >>> first_img_data = dataset.images[0] #doctest: +SKIP @@ -787,23 +638,12 @@ def load_sample_images(): >>> first_img_data.dtype #doctest: +SKIP dtype('uint8') """ - # import PIL only when needed - from ..externals._pilutil import imread - - module_path = join(dirname(__file__), "images") - with open(join(module_path, 'README.txt')) as f: - descr = f.read() - filenames = [join(module_path, filename) - for filename in os.listdir(module_path) - if filename.endswith(".jpg")] - # Load image data for each image in the source folder. - images = [imread(filename) for filename in filenames] - - return Bunch(images=images, - filenames=filenames, - DESCR=descr) + return SampleImages().load() +@deprecated("'load_sample_image' was refactored " + "in version 0.21 and will be removed in 0.23. "
+ "Use 'SampleImages().load' instead") def load_sample_image(image_name): """Load the numpy array of a single sample image @@ -821,28 +661,288 @@ def load_sample_image(image_name): Examples --------- - - >>> from sklearn.datasets import load_sample_image - >>> china = load_sample_image('china.jpg') # doctest: +SKIP + >>> from sklearn.datasets import SampleImages + >>> china = SampleImages('china.jpg').load().images[0] # doctest: +SKIP >>> china.dtype # doctest: +SKIP dtype('uint8') >>> china.shape # doctest: +SKIP (427, 640, 3) - >>> flower = load_sample_image('flower.jpg') # doctest: +SKIP + >>> flower = SampleImages('flower.jpg').load().images[0] # doctest: +SKIP >>> flower.dtype # doctest: +SKIP dtype('uint8') >>> flower.shape # doctest: +SKIP (427, 640, 3) """ - images = load_sample_images() - index = None - for i, filename in enumerate(images.filenames): - if filename.endswith(image_name): - index = i - break - if index is None: - raise AttributeError("Cannot find sample image: %s" % image_name) - return images.images[index] + return SampleImages(image_name).load().images[0] + + +class DatasetLoader(object, metaclass=ABCMeta): + """Abstract class for all dataset loaders in scikit-learn.""" + + def load(self, return_X_y=False) -> Dataset: + bunch = self._raw_data_to_bunch() + if return_X_y: + return bunch.data, bunch.target + return bunch + + @abstractmethod + def _raw_data_to_bunch(self) -> Bunch: + raise NotImplementedError + + +class LocalDatasetLoader(DatasetLoader): + _module_path = dirname(__file__) + _data_dir = join(_module_path, 'data') + _descr_dir = join(_module_path, 'descr') + _images_dir = join(_module_path, 'images') + + @property + def X_file(self) -> Union[str, ndarray]: + raise NotImplementedError + + @property + def y_file(self) -> str: + return self.X_file + + @property + def descr_file(self): + return self.X_file.split('.', maxsplit=1)[0] + '.rst' + + @property + def local_data_paths(self) -> Dict[str, str]: + try: + return { + 'X': 
join(self._data_dir, self.X_file), + 'y': join(self._data_dir, self.y_file), + 'descr': join(self._descr_dir, self.descr_file), + } + except TypeError: + return { + 'X': np.array([join(self._data_dir, f) for f in self.X_file]), + 'y': np.array([join(self._data_dir, f) for f in self.y_file]), + 'descr': join(self._descr_dir, self.descr_file), + } + + @property + def feature_names(self) -> ndarray: + raise NotImplementedError + + @property + def target_names(self) -> ndarray: + raise NotImplementedError + + _attempt_cast_to_int = staticmethod(_attempt_cast_to_int) + + def _raw_data_to_bunch(self) -> Bunch: + bunch = self.read_data() + bunch.DESCR = self._read_description() + return self.process(bunch) + + @abstractmethod + def read_data(self) -> Bunch: + raise NotImplementedError + + @abstractmethod + def process(self, bunch: Bunch) -> Bunch: + bunch.target = self._attempt_cast_to_int(bunch.target) + return bunch + + def _read_description(self, path: str = None) -> str: + descr_path = path or self.local_data_paths['descr'] + with open(descr_path) as descr: + return descr.read() + + def _make_bunch(self, X, y, target_names, + description, images=None) -> Bunch: + return Bunch( + data=X, target=y, + feature_names=self.feature_names, + target_names=target_names, + DESCR=description, + images=images, + data_filename=self.local_data_paths['X'], + target_filename=self.local_data_paths['y'], + filename=self.local_data_paths['X']) # Backwards compatible (0.21) + + +class SimpleCSVLocalDatasetLoader(LocalDatasetLoader): + """Reads a .csv file with: + The first row containing: + - the sample size (int) + - the number of features (int) + - [Optional] the target names corresponding + to their 'int' representation in the y column + The second [Optional] row containing: + - feature variable names with type 'str' + X features in columns [:-1] with type 'int' or 'float' + y target in column [-1] with type 'int' or 'float' + + If you try to read a .csv file containing 'object' or + 
'str' type variable values, it will fail! + """ + def read_data(self) -> Bunch: + X, y, target_names = self.read_X_y_csv(self.local_data_paths['X']) + if self.target_names.size > 0: # class property takes precedence + target_names = self.target_names + return self._make_bunch(X, y, target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + return super().process(bunch) + + @staticmethod + def read_X_y_csv(path: str) -> Tuple[ndarray, ndarray, ndarray]: + with open(path) as f: + firstline = f.readline().rstrip().split(',') + n_features = int(firstline[1]) + col_ixs = tuple(range(n_features + 1)) + target_names = np.array(firstline[2:]) + + csv_arr = np.genfromtxt(path, delimiter=',', usecols=col_ixs, + skip_header=1, dtype=np.float64) + data, target = csv_arr[:, :n_features], csv_arr[:, n_features] + + rows_not_all_nan = ~np.isnan(data).all(axis=1) + data, target = data[rows_not_all_nan], target[rows_not_all_nan] + return data, target, target_names + + +class Wine(SimpleCSVLocalDatasetLoader): + X_file = 'wine_data.csv' + feature_names = np.array([ + 'alcohol', 'malic_acid', + 'ash', 'alcalinity_of_ash', + 'magnesium', 'total_phenols', + 'flavanoids', 'nonflavanoid_phenols', + 'proanthocyanins', 'color_intensity', + 'hue', 'od280/od315_of_diluted_wines', + 'proline']) + target_names = np.array([]) + + +class Iris(SimpleCSVLocalDatasetLoader): + X_file = 'iris.csv' + feature_names = np.array([ + 'sepal length (cm)', 'sepal width (cm)', + 'petal length (cm)', 'petal width (cm)']) + target_names = np.array([]) + + +class BreastCancer(SimpleCSVLocalDatasetLoader): + X_file = 'breast_cancer.csv' + feature_names = np.array([ + 'mean radius', 'mean texture', + 'mean perimeter', 'mean area', + 'mean smoothness', 'mean compactness', + 'mean concavity', 'mean concave points', + 'mean symmetry', 'mean fractal dimension', + 'radius error', 'texture error', + 'perimeter error', 'area error', + 'smoothness error', 'compactness error', + 'concavity error', 'concave
points error', + 'symmetry error', 'fractal dimension error', + 'worst radius', 'worst texture', + 'worst perimeter', 'worst area', + 'worst smoothness', 'worst compactness', + 'worst concavity', 'worst concave points', + 'worst symmetry', 'worst fractal dimension']) + target_names = np.array([]) + + +class Digits(LocalDatasetLoader): + X_file = 'digits.csv.gz' + feature_names = np.array([list(map(lambda d: (d, x).__repr__(), + np.arange(1, 9, dtype='int'))) + for x in np.arange(1, 9, dtype='int')]) + target_names = np.arange(10) + + def __init__(self, n_class=10): + self.n_class = n_class + + def read_data(self): + X = np.loadtxt(self.local_data_paths['X'], delimiter=',') + y = X[:, -1].astype('int') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + X, y = bunch.data, bunch.target + flat_X = X[:, :-1] + images = flat_X.view() + images.shape = (-1, 8, 8) + if self.n_class < 10: + idx = y < self.n_class + flat_X, y = flat_X[idx], y[idx] + images = images[idx] + bunch.data, bunch.target, bunch.images = flat_X, y, images + return bunch + + +class Diabetes(LocalDatasetLoader): + X_file = 'diabetes_data.csv.gz' + y_file = 'diabetes_target.csv.gz' + descr_file = 'diabetes.rst' + feature_names = np.array([ + 'age', 'sex', 'bmi', 'bp', + 's1', 's2', 's3', 's4', 's5', 's6']) + target_names = np.array(['progression']) + + def read_data(self): + X = np.loadtxt(self.local_data_paths['X'], dtype='float') + y = np.loadtxt(self.local_data_paths['y'], dtype='float') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch): + return super().process(bunch) + + +class Linnerud(LocalDatasetLoader): + X_file = 'linnerud_exercise.csv' + y_file = 'linnerud_physiological.csv' + descr_file = 'linnerud.rst' + feature_names = np.array(['Chins', 'Situps', 'Jumps']) + target_names = np.array(['Weight', 'Waist', 'Pulse']) + + def read_data(self) -> Bunch: + X =
np.loadtxt(self.local_data_paths['X'], skiprows=1, dtype='float') + y = np.loadtxt(self.local_data_paths['y'], skiprows=1, dtype='float') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + return bunch + + +class Boston(SimpleCSVLocalDatasetLoader): + X_file = 'boston_house_prices.csv' + feature_names = np.array([ + 'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', + 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']) + target_names = np.array(['MEDV']) + + +class SampleImages(LocalDatasetLoader): + X_file = np.array(['china.jpg', 'flower.jpg']) + feature_names = np.array([]) + target_names = np.array(['china', 'flower']) + _descr_dir = LocalDatasetLoader._images_dir + descr_file = 'README.txt' + + def __init__(self, image_name: str = None): + self.image_name = image_name + self._check_image_name() + + def _check_image_name(self): + if self.image_name and self.image_name not in self.X_file: + msg = 'Cannot find sample image: %s' % self.image_name + raise AttributeError(msg) + + def read_data(self) -> Bunch: + from ..externals._pilutil import imread # import PIL only when needed + names = [self.image_name] if self.image_name else self.X_file + images = [imread(join(self._images_dir, name)) for name in names] + return self._make_bunch(None, None, self.target_names, None, images) + + def process(self, bunch: Bunch): + bunch.filenames = bunch.filename # Backwards compatible (0.21) + return bunch def _pkl_filepath(*args, **kwargs): diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 08a6ba29413cf..60741145d7e1e 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -21,6 +21,14 @@ from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_boston from sklearn.datasets import load_wine +from sklearn.datasets import SampleImages # noqa: F401 +from sklearn.datasets import Digits # noqa: F401 +from sklearn.datasets import Diabetes
# noqa: F401 +from sklearn.datasets import Linnerud # noqa: F401 +from sklearn.datasets import Iris # noqa: F401 +from sklearn.datasets import BreastCancer # noqa: F401 +from sklearn.datasets import Boston # noqa: F401 +from sklearn.datasets import Wine # noqa: F401 from sklearn.datasets.base import Bunch from sklearn.datasets.tests.test_common import check_return_X_y @@ -29,6 +37,8 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises +import numpy.testing.utils as np_test_util + def _remove_dir(path): if os.path.isdir(path): @@ -109,6 +119,7 @@ def test_load_files_w_categories_desc_and_encoding( assert_equal(res.data, ["Hello World!\n"]) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_files_wo_load_content( test_category_dir_1, test_category_dir_2, load_files_root): res = load_files(load_files_root, load_content=False) @@ -118,6 +129,7 @@ def test_load_files_wo_load_content( assert_equal(res.get('data'), None) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_sample_images(): try: res = load_sample_images() @@ -128,6 +140,7 @@ def test_load_sample_images(): warnings.warn("Could not load sample images, PIL is not available.") +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_digits(): digits = load_digits() assert_equal(digits.data.shape, (1797, 64)) @@ -137,12 +150,14 @@ def test_load_digits(): check_return_X_y(digits, partial(load_digits)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_digits_n_class_lt_10(): digits = load_digits(9) assert_equal(digits.data.shape, (1617, 64)) assert_equal(numpy.unique(digits.target).size, 9) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_sample_image(): try: china = load_sample_image('china.jpg') @@ -152,6 +167,7 @@ def test_load_sample_image(): warnings.warn("Could not load sample images, PIL is not available.") 
+@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_missing_sample_image_error(): if pillow_installed: assert_raises(AttributeError, load_sample_image, @@ -160,6 +176,7 @@ def test_load_missing_sample_image_error(): warnings.warn("Could not load sample images, PIL is not available.") +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_diabetes(): res = load_diabetes() assert_equal(res.data.shape, (442, 10)) @@ -171,6 +188,7 @@ def test_load_diabetes(): check_return_X_y(res, partial(load_diabetes)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_linnerud(): res = load_linnerud() assert_equal(res.data.shape, (20, 3)) @@ -184,6 +202,7 @@ def test_load_linnerud(): check_return_X_y(res, partial(load_linnerud)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_iris(): res = load_iris() assert_equal(res.data.shape, (150, 4)) @@ -196,6 +215,7 @@ def test_load_iris(): check_return_X_y(res, partial(load_iris)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_wine(): res = load_wine() assert_equal(res.data.shape, (178, 13)) @@ -207,6 +227,7 @@ def test_load_wine(): check_return_X_y(res, partial(load_wine)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_breast_cancer(): res = load_breast_cancer() assert_equal(res.data.shape, (569, 30)) @@ -219,6 +240,7 @@ def test_load_breast_cancer(): check_return_X_y(res, partial(load_breast_cancer)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_boston(): res = load_boston() assert_equal(res.data.shape, (506, 13)) @@ -259,7 +281,90 @@ def test_bunch_pickle_generated_with_0_16_and_read_with_0_17(): assert_equal(bunch_from_pkl['key'], 'changed') +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_bunch_dir(): # check that dir (important for autocomplete) shows attributes data = load_iris() assert "data" in dir(data) + + 
+@pytest.mark.parametrize('test_loader,expected_target_dtype', [ + ('Iris', 'int'), + ('Boston', 'float'), + ('BreastCancer', 'int'), + ('Digits', 'int'), + ('Diabetes', 'int'), + ('Linnerud', 'float'), + ('Wine', 'int') +]) +def test_dataset_loader_dtype(test_loader, expected_target_dtype): + assert eval(test_loader)().load().target.dtype == expected_target_dtype + + +@pytest.mark.parametrize('test_loader', [ + 'Iris', + 'Boston', + 'BreastCancer', + 'Digits', + 'Diabetes', + 'Linnerud', + 'Wine' +]) +def test_dataset_loader_bunch_paths(test_loader): + bunch = eval(test_loader)().load() + paths = bunch.filename, bunch.data_filename, bunch.target_filename + assert all(list(map(os.path.exists, paths))) is True + + +@pytest.mark.parametrize('test_loader,exp_features,exp_targets,exp_n', [ + ('Iris', 4, 3, 150), + ('Boston', 13, 1, 506), + ('BreastCancer', 30, 2, 569), + ('Digits', 64, 10, 1797), + ('Diabetes', 10, 1, 442), + ('Linnerud', 3, 3, 20), + ('Wine', 13, 3, 178)]) +def test_dataset_loader_shape(test_loader, exp_features, + exp_targets, exp_n): + bunch = eval(test_loader)().load() + n_features, m_features = bunch.data.shape[:2] + n_targets, m_targets = bunch.target.shape[0], bunch.target_names.size + assert (m_features == exp_features) and \ + (m_targets == exp_targets) and \ + (n_features == n_targets == exp_n) + + +@pytest.mark.parametrize('test_loader', [ + 'Iris', + 'Boston', + 'BreastCancer', + 'Digits', + 'Diabetes', + 'Linnerud', + 'Wine' +]) +def test_dataset_loader_check_nan(test_loader): + bunch = eval(test_loader)().load() + data, target = bunch.data, bunch.target + np_test_util.assert_equal(numpy.isnan(data).any(), False) + np_test_util.assert_equal(numpy.isnan(target).any(), False) + + +def test_load_data_deprecated(): + from ..base import load_data + iris_path = Iris().local_data_paths['X'] + pytest.deprecated_call(load_data, iris_path) + + +@pytest.mark.parametrize('test_deprecated_fun', [ + 'load_wine', + 'load_iris', + 'load_digits', + 
'load_diabetes', + 'load_boston', + 'load_breast_cancer', + 'load_linnerud', + 'load_sample_images' +]) +def test_functional_load_deprecated(test_deprecated_fun): + pytest.deprecated_call(eval(test_deprecated_fun))