8000 MAINT Parameters validation for datasets.make_blobs (#25983) · Veghit/scikit-learn@cf695ed · GitHub
[go: up one dir, main page]

Skip to content

Commit cf695ed

Browse files
Théophile Barangerglemaitre
authored andcommitted
MAINT Parameters validation for datasets.make_blobs (scikit-learn#25983)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent d7aa066 commit cf695ed

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

sklearn/datasets/_samples_generator.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,18 @@ def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None):
855855
return X, y
856856

857857

858+
@validate_params(
859+
{
860+
"n_samples": [Interval(Integral, 1, None, closed="left"), "array-like"],
861+
"n_features": [Interval(Integral, 1, None, closed="left")],
862+
"centers": [Interval(Integral, 1, None, closed="left"), "array-like", None],
863+
"cluster_std": [Interval(Real, 0, None, closed="left"), "array-like"],
864+
"center_box": [tuple],
865+
"shuffle": ["boolean"],
866+
"random_state": ["random_state"],
867+
"return_centers": ["boolean"],
868+
}
869+
)
858870
def make_blobs(
859871
n_samples=100,
860872
n_features=2,
@@ -884,7 +896,7 @@ def make_blobs(
884896
n_features : int, default=2
885897
The number of features for each sample.
886898
887-
centers : int or ndarray of shape (n_centers, n_features), default=None
899+
centers : int or array-like of shape (n_centers, n_features), default=None
888900
The number of centers to generate, or the fixed center locations.
889901
If n_samples is an int and centers is None, 3 centers are generated.
890902
If n_samples is array-like, centers must be
@@ -967,22 +979,19 @@ def make_blobs(
967979
centers = generator.uniform(
968980
center_box[0], center_box[1], size=(n_centers, n_features)
969981
)
970-
try:
971-
assert len(centers) == n_centers
972-
except TypeError as e:
982+
if not isinstance(centers, Iterable):
973983
raise ValueError(
974984
"Parameter `centers` must be array-like. Got {!r} instead".format(
975985
centers
976986
)
977-
) from e
978-
except AssertionError as e:
987+
)
988+
if len(centers) != n_centers:
979989
raise ValueError(
980990
"Length of `n_samples` not consistent with number of "
981991
f"centers. Got n_samples = {n_samples} and centers = {centers}"
982-
) from e
983-
else:
984-
centers = check_array(centers)
985-
n_features = centers.shape[1]
992+
)
993+
centers = check_array(centers)
994+
n_features = centers.shape[1]
986995

987996
# stds: if cluster_std is given as list, it must be consistent
988997
# with the n_centers

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _check_function_param_validation(
130130
"sklearn.datasets.load_svmlight_file",
131131
"sklearn.datasets.load_svmlight_files",
132132
"sklearn.datasets.make_biclusters",
133+
"sklearn.datasets.make_blobs",
133134
"sklearn.datasets.make_checkerboard",
134135
"sklearn.datasets.make_circles",
135136
"sklearn.datasets.make_classification",

0 commit comments

Comments
 (0)
0