8000 MAINT Parameters validation for sklearn.preprocessing.normalize (#26069) · scikit-learn/scikit-learn@fb52671 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fb52671

Browse files
MAINT Parameters validation for sklearn.preprocessing.normalize (#26069)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 310c707 commit fb52671

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

sklearn/preprocessing/_data.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,15 @@ def robust_scale(
17771777
return X
17781778

17791779

1780+
@validate_params(
1781+
{
1782+
"X": ["array-like", "sparse matrix"],
1783+
"norm": [StrOptions({"l1", "l2", "max"})],
1784+
"axis": [Options(Integral, {0, 1})],
1785+
"copy": ["boolean"],
1786+
"return_norm": ["boolean"],
1787+
}
1788+
)
17801789
def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
17811790
"""Scale input vectors individually to unit norm (vector length).
17821791
@@ -1826,15 +1835,10 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
18261835
see :ref:`examples/preprocessing/plot_all_scaling.py
18271836
<sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
18281837
"""
1829-
if norm not in ("l1", "l2", "max"):
1830-
raise ValueError("'%s' is not a supported norm" % norm)
1831-
18321838
if axis == 0:
18331839
sparse_format = "csc"
1834-
elif axis == 1:
1840+
else: # axis == 1:
18351841
sparse_format = "csr"
1836-
else:
1837-
raise ValueError("'%d' is not a supported axis" % axis)
18381842

18391843
X = check_array(
18401844
X,

sklearn/preprocessing/tests/test_data.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,10 +1988,6 @@ def test_normalize():
19881988
# Only tests functionality not used by the tests for Normalizer.
19891989
X = np.random.RandomState(37).randn(3, 2)
19901990
assert_array_equal(normalize(X, copy=False), normalize(X.T, axis=0, copy=False).T)
1991-
with pytest.raises(ValueError):
1992-
normalize([[0]], axis=2)
1993-
with pytest.raises(ValueError):
1994-
normalize([[0]], norm="l3")
19951991

19961992
rs = np.random.RandomState(0)
19971993
X_dense = rs.randn(10, 5)

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def _check_function_param_validation(
228228
"sklearn.preprocessing.binarize",
229229
"sklearn.preprocessing.label_binarize",
230230
"sklearn.preprocessing.maxabs_scale",
231+
"sklearn.preprocessing.normalize",
231232
"sklearn.preprocessing.scale",
232233
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
233234
"sklearn.svm.l1_min_c",

0 commit comments

Comments
 (0)
0