diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index 7252c050bef3c..c6f1a5f9a90c8 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -10,7 +10,8 @@ from os import listdir, makedirs, remove from os.path import join, exists, isdir - +from ..utils._param_validation import validate_params, Interval, Hidden +from numbers import Integral, Real import logging import numpy as np @@ -231,6 +232,18 @@ def _fetch_lfw_people( return faces, target, target_names +@validate_params( + { + "data_home": [str, None], + "funneled": ["boolean"], + "resize": [Interval(Real, 0, None, closed="neither"), None], + "min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None], + "color": ["boolean"], + "slice_": [tuple, Hidden(None)], + "download_if_missing": ["boolean"], + "return_X_y": ["boolean"], + } +) def fetch_lfw_people( *, data_home=None, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..22b0aa2653678 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -108,6 +108,7 @@ def _check_function_param_validation( "sklearn.datasets.fetch_california_housing", "sklearn.datasets.fetch_covtype", "sklearn.datasets.fetch_kddcup99", + "sklearn.datasets.fetch_lfw_people", "sklearn.datasets.fetch_olivetti_faces", "sklearn.datasets.load_svmlight_file", "sklearn.datasets.load_svmlight_files",