8000 MAINT Parameters validation for sklearn.datasets.fetch_20newsgroups (… · thomasjpfan/scikit-learn@94abb83 · GitHub
[go: up one dir, main page]

Skip to content

Commit 94abb83

Browse files
author
Théophile Baranger
authored
MAINT Parameters validation for sklearn.datasets.fetch_20newsgroups (scikit-learn#25786)
1 parent ea59b1e commit 94abb83

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

sklearn/datasets/_twenty_newsgroups.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from ..feature_extraction.text import CountVectorizer
4747
from .. import preprocessing
4848
from ..utils import check_random_state, Bunch
49+
from ..utils._param_validation import StrOptions, validate_params
4950

5051
logger = logging.getLogger(__name__)
5152

@@ -149,6 +150,18 @@ def strip_newsgroup_footer(text):
149150
return text
150151

151152

153+
@validate_params(
154+
{
155+
"data_home": [str, None],
156+
"subset": [StrOptions({"train", "test", "all"})],
157+
"categories": ["array-like", None],
158+
"shuffle": ["boolean"],
159+
"random_state": ["random_state"],
160+
"remove": [tuple],
161+
"download_if_missing": ["boolean"],
162+
"return_X_y": ["boolean"],
163+
}
164+
)
152165
def fetch_20newsgroups(
153166
*,
154167
data_home=None,
@@ -287,10 +300,6 @@ def fetch_20newsgroups(
287300
data.data = data_lst
288301
data.target = np.array(target)
289302
data.filenames = np.array(filenames)
290-
else:
291-
raise ValueError(
292-
"subset can only be 'train', 'test' or 'all', got '%s'" % subset
293-
)
294303

295304
fdescr = load_descr("twenty_newsgroups.rst")
296305

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _check_function_param_validation(
103103
"sklearn.covariance.empirical_covariance",
104104
"sklearn.covariance.shrunk_covariance",
105105
"sklearn.datasets.dump_svmlight_file",
106+
"sklearn.datasets.fetch_20newsgroups",
106107
"sklearn.datasets.fetch_california_housing",
107108
"sklearn.datasets.fetch_covtype",
108109
"sklearn.datasets.fetch_kddcup99",

0 commit comments

Comments
 (0)
0