MAINT Parameters validation for datasets.make_multilabel_classification (scikit-learn#25920) · thomasjpfan/scikit-learn@e3d1f9a · GitHub
[go: up one dir, main page]

Skip to content

Commit e3d1f9a

Browse files
author
Théophile Baranger
authored
MAINT Parameters validation for datasets.make_multilabel_classification (scikit-learn#25920)
1 parent 01f8d34 commit e3d1f9a

File tree

3 files changed

+15
-26
lines changed

3 files changed

+15
-26
lines changed

sklearn/datasets/_samples_generator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,20 @@ def make_classification(
309309
return X, y
310310

311311

312+
@validate_params(
313+
{
314+
"n_samples": [Interval(Integral, 1, None, closed="left")],
315+
"n_features": [Interval(Integral, 1, None, closed="left")],
316+
"n_classes": [Interval(Integral, 1, None, closed="left")],
317+
"n_labels": [Interval(Integral, 0, None, closed="left")],
318+
"length": [Interval(Integral, 1, None, closed="left")],
319+
"allow_unlabeled": ["boolean"],
320+
"sparse": ["boolean"],
321+
"return_indicator": [StrOptions({"dense", "sparse"}), "boolean"],
322+
"return_distributions": ["boolean"],
323+
"random_state": ["random_state"],
324+
}
325+
)
312326
def make_multilabel_classification(
313327
n_samples=100,
314328
n_features=20,
@@ -398,18 +412,6 @@ def make_multilabel_classification(
398412
The probability of each feature being drawn given each class.
399413
Only returned if ``return_distributions=True``.
400414
"""
401-
if n_classes < 1:
402-
raise ValueError(
403-
"'n_classes' should be an integer greater than 0. Got {} instead.".format(
404-
n_classes
405-
)
406-
)
407-
if length < 1:
408-
raise ValueError(
409-
"'length' should be an integer greater than 0. Got {} instead.".format(
410-
length
411-
)
412-
)
413415

414416
generator = check_random_state(random_state)
415417
p_c = generator.uniform(size=n_classes)
@@ -469,8 +471,6 @@ def sample_example():
469471
if return_indicator in (True, "sparse", "dense"):
470472
lb = MultiLabelBinarizer(sparse_output=(return_indicator == "sparse"))
471473
Y = lb.fit([range(n_classes)]).transform(Y)
472-
elif return_indicator is not False:
473-
raise ValueError("return_indicator must be either 'sparse', 'dense' or False.")
474474
if return_distributions:
475475
return X, Y, p_c, p_w_c
476476
return X, Y

sklearn/datasets/tests/test_samples_generator.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -283,18 +283,6 @@ def test_make_multilabel_classification_return_indicator_sparse():
283283
assert sp.issparse(Y)
284284

285285

286-
@pytest.mark.parametrize(
287-
"params, err_msg",
288-
[
289-
({"n_classes": 0}, "'n_classes' should be an integer"),
290-
({"length": 0}, "'length' should be an integer"),
291-
],
292-
)
293-
def test_make_multilabel_classification_valid_arguments(params, err_msg):
294-
with pytest.raises(ValueError, match=err_msg):
295-
make_multilabel_classification(**params)
296-
297-
298286
def test_make_hastie_10_2():
299287
X, y = make_hastie_10_2(n_samples=100, random_state=0)
300288
assert X.shape == (100, 10), "X shape mismatch"

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _check_function_param_validation(
133133
"sklearn.datasets.make_classification",
134134
"sklearn.datasets.make_friedman1",
135135
"sklearn.datasets.make_low_rank_matrix",
136+
"sklearn.datasets.make_multilabel_classification",
136137
"sklearn.datasets.make_regression",
137138
"sklearn.datasets.make_sparse_coded_signal",
138139
"sklearn.decomposition.sparse_encode",

0 commit comments

Comments
 (0)
0