8000 MAINT Parameters validation for sklearn.datasets.fetch_rcv1 (#26126) · scikit-learn/scikit-learn@c8fb561 · GitHub
[go: up one dir, main page]

Skip to content

Commit c8fb561

Browse files
authored
MAINT Parameters validation for sklearn.datasets.fetch_rcv1 (#26126)
1 parent 24f68c2 commit c8fb561

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

sklearn/datasets/_rcv1.py

+11
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ._svmlight_format_io import load_svmlight_files
2727
from ..utils import shuffle as shuffle_
2828
from ..utils import Bunch
29+
from ..utils._param_validation import validate_params, StrOptions
2930

3031

3132
# The original vectorized data can be found at:
@@ -76,6 +77,16 @@
7677
logger = logging.getLogger(__name__)
7778

7879

80+
@validate_params(
81+
{
82+
"data_home": [str, None],
83+
"subset": [StrOptions({"train", "test", "all"})],
84+
"download_if_missing": ["boolean"],
85+
"random_state": ["random_state"],
86+
"shuffle": ["boolean"],
87+
"return_X_y": ["boolean"],
88+
}
89+
)
7990
def fetch_rcv1(
8091
*,
8192
data_home=None,

sklearn/tests/test_public_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _check_function_param_validation(
127127
"sklearn.datasets.fetch_lfw_pairs",
128128
"sklearn.datasets.fetch_lfw_people",
129129
"sklearn.datasets.fetch_olivetti_faces",
130+
"sklearn.datasets.fetch_rcv1",
130131
"sklearn.datasets.load_svmlight_file",
131132
"sklearn.datasets.load_svmlight_files",
132133
"sklearn.datasets.make_biclusters",

0 commit comments

Comments
 (0)
0