8000 MAINT Parameters validation for sklearn.preprocessing.quantile_transf… · scikit-learn/scikit-learn@b2c3881 · GitHub
[go: up one dir, main page]

Skip to content

Commit b2c3881

Browse files
2357juanjeremiedbb
andauthored
MAINT Parameters validation for sklearn.preprocessing.quantile_transform (#26144)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 72610f1 commit b2c3881

File tree

3 files changed

+11
-16
lines changed

3 files changed

+11
-16
lines changed

sklearn/preprocessing/_data.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2852,6 +2852,9 @@ def _more_tags(self):
28522852
return {"allow_nan": True}
28532853

28542854

2855+
@validate_params(
2856+
{"X": ["array-like", "sparse matrix"], "axis": [Options(Integral, {0, 1})]}
2857+
)
28552858
def quantile_transform(
28562859
X,
28572860
*,
@@ -2986,13 +2989,10 @@ def quantile_transform(
29862989
copy=copy,
29872990
)
29882991
if axis == 0:
2989-
return n.fit_transform(X)
2990-
elif axis == 1:
2991-
return n.fit_transform(X.T).T
2992-
else:
2993-
raise ValueError(
2994-
"axis should be either equal to 0 or 1. Got axis={}".format(axis)
2995-
)
2992+
X = n.fit_transform(X)
2993+
else: # axis == 1
2994+
X = n.fit_transform(X.T).T
2995+
return X
29962996

29972997

29982998
class PowerTransformer(OneToOneFeatureMixin, TransformerMixin, BaseEstimator):

sklearn/preprocessing/tests/test_data.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,15 +2246,6 @@ def test_fit_cold_start():
22462246
scaler.fit_transform(X_2d)
22472247

22482248

2249-
def test_quantile_transform_valid_axis():
2250-
X = np.array([[0, 25, 50, 75, 100], [2, 4, 6, 8, 10], [2.6, 4.1, 2.3, 9.5, 0.1]])
2251-
2252-
with pytest.raises(
2253-
ValueError, match="axis should be either equal to 0 or 1. Got axis=2"
2254-
):
2255-
quantile_transform(X.T, axis=2)
2256-
2257-
22582249
@pytest.mark.parametrize("method", ["box-cox", "yeo-johnson"])
22592250
def test_power_transformer_notfitted(method):
22602251
pt = PowerTransformer(method=method)

sklearn/tests/test_public_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ def test_function_param_validation(func_module):
289289
("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
290290
("sklearn.preprocessing.minmax_scale", "sklearn.preprocessing.MinMaxScaler"),
291291
("sklearn.preprocessing.power_transform", "sklearn.preprocessing.PowerTransformer"),
292+
(
293+
"sklearn.preprocessing.quantile_transform",
294+
"sklearn.preprocessing.QuantileTransformer",
295+
),
292296
("sklearn.preprocessing.robust_scale", "sklearn.preprocessing.RobustScaler"),
293297
]
294298

0 commit comments

Comments
 (0)
0