8000 MAINT Parameters validation for sklearn.preprocessing.normalize (#26069) · thomasjpfan/scikit-learn@fb52671 · GitHub
[go: up one dir, main page]

Skip to content

Commit fb52671

Browse files
MAINT Parameters validation for sklearn.preprocessing.normalize (scikit-learn#26069)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 310c707 commit fb52671

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

sklearn/preprocessing/_data.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,15 @@ def robust_scale(
17771777
return X
17781778

17791779

1780+
@validate_params(
1781+
{
1782+
"X": ["array-like", "sparse matrix"],
1783+
"norm": [StrOptions({"l1", "l2", "max"})],
1784+
"axis": [Options(Integral, {0, 1})],
1785+
"copy": ["boolean"],
1786+
"return_norm": ["boolean"],
1787+
}
1788+
)
17801789
def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
17811790
"""Scale input vectors individually to unit norm (vector length).
17821791
@@ -1826,15 +1835,10 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False):
18261835
see :ref:`examples/preprocessing/plot_all_scaling.py
18271836
<sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
18281837
"""
1829-
if norm not in ("l1", "l2", "max"):
1830-
raise ValueError("'%s' is not a supported norm" % norm)
1831-
18321838
if axis == 0:
18331839
sparse_format = "csc"
1834-
elif axis == 1:
1840+
else: # axis == 1:
18351841
sparse_format = "csr"
1836-
else:
1837-
raise ValueError("'%d' is not a supported axis" % axis)
18381842

18391843
X = check_array(
18401844
X,

sklearn/preprocessing/tests/test_data.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,10 +1988,6 @@ def test_normalize():
19881988
# Only tests functionality not used by the tests for Normalizer.
19891989
X = np.random.RandomState(37).randn(3, 2)
19901990
assert_array_equal(normalize(X, copy=False), normalize(X.T, axis=0, copy=False).T)
1991-
with pytest.raises(ValueError):
1992-
normalize([[0]], axis=2)
1993-
with pytest.raises(ValueError):
1994-
normalize([[0]], norm="l3")
19951991

19961992
rs = np.random.RandomState(0)
19971993
X_dense = rs.randn(10, 5)

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def _check_function_param_validation(
228228
"sklearn.preprocessing.binarize",
229229
"sklearn.preprocessing.label_binarize",
230230
"sklearn.preprocessing.maxabs_scale",
231+
"sklearn.preprocessing.normalize",
231232
"sklearn.preprocessing.scale",
232233
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
233234
"sklearn.svm.l1_min_c",

0 commit comments

Comments
 (0)
0