MAINT Parameters validation for datasets.make_multilabel_classification · Pull Request #25920 · scikit-learn/scikit-learn
Merged
28 changes: 14 additions & 14 deletions sklearn/datasets/_samples_generator.py
@@ -309,6 +309,20 @@ def make_classification(
     return X, y


+@validate_params(
+    {
+        "n_samples": [Interval(Integral, 1, None, closed="left")],
+        "n_features": [Interval(Integral, 1, None, closed="left")],
+        "n_classes": [Interval(Integral, 1, None, closed="left")],
+        "n_labels": [Interval(Integral, 0, None, closed="left")],
+        "length": [Interval(Integral, 1, None, closed="left")],
+        "allow_unlabeled": ["boolean"],
+        "sparse": ["boolean"],
+        "return_indicator": [StrOptions({"dense", "sparse"}), "boolean"],
Review comment (Member): Technically we'd have False as a public valid option and True as a hidden valid option, but let's not bother.

"return_distributions": ["boolean"],
"random_state": ["random_state"],
}
)
 def make_multilabel_classification(
     n_samples=100,
     n_features=20,
@@ -398,18 +412,6 @@ def make_multilabel_classification(
         The probability of each feature being drawn given each class.
         Only returned if ``return_distributions=True``.
     """
-    if n_classes < 1:
-        raise ValueError(
-            "'n_classes' should be an integer greater than 0. Got {} instead.".format(
-                n_classes
-            )
-        )
-    if length < 1:
-        raise ValueError(
-            "'length' should be an integer greater than 0. Got {} instead.".format(
-                length
-            )
-        )

     generator = check_random_state(random_state)
     p_c = generator.uniform(size=n_classes)
@@ -469,8 +471,6 @@ def sample_example():
     if return_indicator in (True, "sparse", "dense"):
         lb = MultiLabelBinarizer(sparse_output=(return_indicator == "sparse"))
         Y = lb.fit([range(n_classes)]).transform(Y)
-    elif return_indicator is not False:
-        raise ValueError("return_indicator must be either 'sparse', 'dense' or False.")
     if return_distributions:
         return X, Y, p_c, p_w_c
     return X, Y
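
A minimal sketch of the behaviour this change provides (not part of the diff): invalid arguments are now rejected up front by @validate_params with an InvalidParameterError, which subclasses ValueError, instead of by the hand-written checks removed above. The call values and error class below are assumptions based on scikit-learn's parameter-validation framework, not taken from this PR.

from sklearn.datasets import make_multilabel_classification

# Valid call: n_classes and length satisfy Interval(Integral, 1, None, closed="left").
X, Y = make_multilabel_classification(n_samples=50, n_classes=3, random_state=0)

# Invalid call: n_classes=0 violates the declared interval, so the decorator
# rejects it before any data is generated.
try:
    make_multilabel_classification(n_classes=0)
except ValueError as exc:  # InvalidParameterError subclasses ValueError
    print(exc)
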
12 changes: 0 additions & 12 deletions sklearn/datasets/tests/test_samples_generator.py
@@ -283,18 +283,6 @@ def test_make_multilabel_classification_return_indicator_sparse():
     assert sp.issparse(Y)


-@pytest.mark.parametrize(
-    "params, err_msg",
-    [
-        ({"n_classes": 0}, "'n_classes' should be an integer"),
-        ({"length": 0}, "'length' should be an integer"),
-    ],
-)
-def test_make_multilabel_classification_valid_arguments(params, err_msg):
-    with pytest.raises(ValueError, match=err_msg):
-        make_multilabel_classification(**params)


 def test_make_hastie_10_2():
     X, y = make_hastie_10_2(n_samples=100, random_state=0)
     assert X.shape == (100, 10), "X shape mismatch"
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
@@ -133,6 +133,7 @@ def _check_function_param_validation(
"sklearn.datasets.make_classification",
"sklearn.datasets.make_friedman1",
"sklearn.datasets.make_low_rank_matrix",
"sklearn.datasets.make_multilabel_classification",
"sklearn.datasets.make_regression",
"sklearn.datasets.make_sparse_coded_signal",
"sklearn.decomposition.sparse_encode",
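
Adding the function to this list opts it into scikit-learn's shared parameter-validation test, which is why the bespoke test above could be deleted. As a rough, hypothetical illustration only (not the actual common-test implementation; the import path for InvalidParameterError and the test name are assumptions), the kind of check that now covers the removed test looks like this:

import pytest

from sklearn.datasets import make_multilabel_classification
from sklearn.utils._param_validation import InvalidParameterError


# Hypothetical stand-in for the deleted test: each invalid value should now be
# rejected by the @validate_params constraints before the function body runs.
@pytest.mark.parametrize(
    "params",
    [{"n_classes": 0}, {"length": 0}, {"return_indicator": "invalid"}],
)
def test_make_multilabel_classification_invalid_params(params):
    with pytest.raises(InvalidParameterError):
        make_multilabel_classification(**params)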