|
46 | 46 | from ..feature_extraction.text import CountVectorizer |
47 | 47 | from .. import preprocessing |
48 | 48 | from ..utils import check_random_state, Bunch |
| 49 | +from ..utils._param_validation import StrOptions, validate_params |
49 | 50 |
|
50 | 51 | logger = logging.getLogger(__name__) |
51 | 52 |
|
@@ -149,6 +150,18 @@ def strip_newsgroup_footer(text): |
149 | 150 | return text |
150 | 151 |
|
151 | 152 |
|
| 153 | +@validate_params( |
| 154 | + { |
| 155 | + "data_home": [str, None], |
| 156 | + "subset": [StrOptions({"train", "test", "all"})], |
| 157 | + "categories": ["array-like", None], |
| 158 | + "shuffle": ["boolean"], |
| 159 | + "random_state": ["random_state"], |
| 160 | + "remove": [tuple], |
| 161 | + "download_if_missing": ["boolean"], |
| 162 | + "return_X_y": ["boolean"], |
| 163 | + } |
| 164 | +) |
152 | 165 | def fetch_20newsgroups( |
153 | 166 | *, |
154 | 167 | data_home=None, |
@@ -287,10 +300,6 @@ def fetch_20newsgroups( |
287 | 300 | data.data = data_lst |
288 | 301 | data.target = np.array(target) |
289 | 302 | data.filenames = np.array(filenames) |
290 | | - else: |
291 | | - raise ValueError( |
292 | | - "subset can only be 'train', 'test' or 'all', got '%s'" % subset |
293 | | - ) |
294 | 303 |
|
295 | 304 | fdescr = load_descr("twenty_newsgroups.rst") |
296 | 305 |
|
|
0 commit comments