8000 FIX MaxAbsScaler on sparse matrices with 1 row · scikit-learn/scikit-learn@00996a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 00996a2

Browse files
Jeffrey04ogrisel
authored andcommitted
FIX MaxAbsScaler on sparse matrices with 1 row
1 parent 76e0273 commit 00996a2

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

sklearn/preprocessing/data.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -788,10 +788,7 @@ def transform(self, X, y=None):
788788
warnings.warn(DEPRECATION_MSG_1D, DeprecationWarning)
789789

790790
if sparse.issparse(X):
791-
if X.shape[0] == 1:
792-
inplace_row_scale(X, 1.0 / self.scale_)
793-
else:
794-
inplace_column_scale(X, 1.0 / self.scale_)
791+
inplace_column_scale(X, 1.0 / self.scale_)
795792
else:
796793
X /= self.scale_
797794
return X
@@ -811,10 +808,7 @@ def inverse_transform(self, X):
811808
warnings.warn(DEPRECATION_MSG_1D, DeprecationWarning)
812809

813810
if sparse.issparse(X):
814-
if X.shape[0] == 1:
815-
inplace_row_scale(X, self.scale_)
816-
else:
817-
inplace_column_scale(X, self.scale_)
811+
inplace_column_scale(X, self.scale_)
818812
else:
819813
X *= self.scale_
820814
return X

sklearn/preprocessing/tests/test_data.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,18 @@ def test_maxabs_scaler_large_negative_value():
937937
assert_array_almost_equal(X_trans, X_expected)
938938

939939

940+
def test_maxabs_scaler_transform_one_row_csr():
941+
"""Check MaxAbsScaler on transforming csr matrix with one row"""
942+
X = sparse.csr_matrix([[0.5, 1., 1.]])
943+
scaler = MaxAbsScaler()
944+
scaler = scaler.fit(X)
945+
X_trans = scaler.transform(X)
946+
X_expected = sparse.csr_matrix([[1., 1., 1.]])
947+
assert_array_almost_equal(X_trans.toarray(), X_expected.toarray())
948+
X_scaled_back = scaler.inverse_transform(X_trans)
949+
assert_array_almost_equal(X.toarray(), X_scaled_back.toarray())
950+
951+
940952
@ignore_warnings
941953
def test_deprecation_minmax_scaler():
942954
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0