8000 MAINT Use `scipy.sparse.isspmatrix_*` (#26420) · scikit-learn/scikit-learn@41b0bd8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 41b0bd8

Browse files
authored
MAINT Use scipy.sparse.isspmatrix_* (#26420)
1 parent b32e5c7 commit 41b0bd8

File tree

8 files changed

+40
-34
lines changed

8 files changed

+40
-34
lines changed

sklearn/model_selection/tests/test_split.py

Lines changed: 8 additions & 3 deletions
E7F5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
import pytest
44
import re
55
import numpy as np
6-
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
6+
from scipy.sparse import (
7+
coo_matrix,
8+
csc_matrix,
9+
csr_matrix,
10+
isspmatrix_csr,
11+
)
712
from scipy import stats
813
from scipy.special import comb
914
from itertools import combinations
@@ -1327,8 +1332,8 @@ def test_train_test_split_sparse():
13271332
for InputFeatureType in sparse_types:
13281333
X_s = InputFeatureType(X)
13291334
X_train, X_test = train_test_split(X_s)
1330-
assert isinstance(X_train, csr_matrix)
1331-
assert isinstance(X_test, csr_matrix)
1335+
assert isspmatrix_csr(X_train)
1336+
assert isspmatrix_csr(X_test)
13321337

13331338

13341339
def test_train_test_split_mock_pandas():

sklearn/preprocessing/tests/test_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,7 +1851,7 @@ def test_normalizer_l1():
18511851
X_norm = normalizer = Normalizer(norm="l2", copy=False).transform(X)
18521852

18531853
assert X_norm is not X
1854-
assert isinstance(X_norm, sparse.csr_matrix)
1854+
assert sparse.isspmatrix_csr(X_norm)
18551855

18561856
X_norm = toarray(X_norm)
18571857
for i in range(3):
@@ -1898,7 +1898,7 @@ def test_normalizer_l2():
18981898
X_norm = normalizer = Normalizer(norm="l2", copy=False).transform(X)
18991899

19001900
assert X_norm is not X
1901-
assert isinstance(X_norm, sparse.csr_matrix)
1901+
assert sparse.isspmatrix_csr(X_norm)
19021902

19031903
X_norm = toarray(X_norm)
19041904
for i in range(3):
@@ -1946,7 +1946,7 @@ def test_normalizer_max():
19461946
X_norm = normalizer = Normalizer(norm="l2", copy=False).transform(X)
19471947

19481948
assert X_norm is not X
1949-
assert isinstance(X_norm, sparse.csr_matrix)
1949+
assert sparse.isspmatrix_csr(X_norm)
19501950

19511951
X_norm = toarray(X_norm)
19521952
for i in range(3):

sklearn/preprocessing/tests/test_polynomial.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def test_polynomial_features_csc_X(deg, include_bias, interaction_only, dtype):
625625
Xt_csc = est.fit_transform(X_csc.astype(dtype))
626626
Xt_dense = est.fit_transform(X.astype(dtype))
627627

628-
assert isinstance(Xt_csc, sparse.csc_matrix)
628+
assert sparse.isspmatrix_csc(Xt_csc)
629629
assert Xt_csc.dtype == Xt_dense.dtype
630630
assert_array_almost_equal(Xt_csc.A, Xt_dense)
631631

@@ -652,7 +652,7 @@ def test_polynomial_features_csr_X(deg, include_bias, interaction_only, dtype):
652652
Xt_csr = est.fit_transform(X_csr.astype(dtype))
653653
Xt_dense = est.fit_transform(X.astype(dtype, copy=False))
654654

655-
assert isinstance(Xt_csr, sparse.csr_matrix)
655+
assert sparse.isspmatrix_csr(Xt_csr)
656656
assert Xt_csr.dtype == Xt_dense.dtype
657657
assert_array_almost_equal(Xt_csr.A, Xt_dense)
658658

@@ -711,7 +711,7 @@ def test_polynomial_features_csr_X_floats(deg, include_bias, interaction_only, d
711711
Xt_csr = est.fit_transform(X_csr.astype(dtype))
712712
Xt_dense = est.fit_transform(X.astype(dtype))
713713

714-
assert isinstance(Xt_csr, sparse.csr_matrix)
714+
assert sparse.isspmatrix_csr(Xt_csr)
715715
assert Xt_csr.dtype == Xt_dense.dtype
716716
assert_array_almost_equal(Xt_csr.A, Xt_dense)
717717

@@ -742,7 +742,7 @@ def test_polynomial_features_csr_X_zero_row(zero_row_index, deg, interaction_onl
742742
Xt_csr = est.fit_transform(X_csr)
743743
Xt_dense = est.fit_transform(X)
744744

745-
assert isinstance(Xt_csr, sparse.csr_matrix)
745+
assert sparse.isspmatrix_csr(Xt_csr)
746746
assert Xt_csr.dtype == Xt_dense.dtype
747747
assert_array_almost_equal(Xt_csr.A, Xt_dense)
748748

@@ -763,7 +763,7 @@ def test_polynomial_features_csr_X_degree_4(include_bias, interaction_only):
763763
Xt_csr = est.fit_transform(X_csr)
764764
Xt_dense = est.fit_transform(X)
765765

766-
assert isinstance(Xt_csr, sparse.csr_matrix)
766+
assert sparse.isspmatrix_csr(X_csr)
767767
assert Xt_csr.dtype == Xt_dense.dtype
768768
assert_array_almost_equal(Xt_csr.A, Xt_dense)
769769

@@ -791,7 +791,7 @@ def test_polynomial_features_csr_X_dim_edges(deg, dim, interaction_only):
791791
Xt_csr = est.fit_transform(X_csr)
792792
Xt_dense = est.fit_transform(X)
793793

794-
assert isinstance(Xt_csr, sparse.csr_matrix)
794+
assert sparse.isspmatrix_csr(Xt_csr)
795795
assert Xt_csr.dtype == Xt_dense.dtype
796796
assert_array_almost_equal(Xt_csr.A, Xt_dense)
797797

sklearn/tree/_splitter.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ from cython cimport final
2020

2121
import numpy as np
2222

23-
from scipy.sparse import csc_matrix
23+
from scipy.sparse import isspmatrix_csc
2424

2525
from ._utils cimport log
2626
from ._utils cimport rand_int
@@ -1041,7 +1041,7 @@ cdef class SparsePartitioner:
10411041
DTYPE_t[::1] feature_values,
10421042
const unsigned char[::1] feature_has_missing,
10431043
):
1044-
if not isinstance(X, csc_matrix):
1044+
if not isspmatrix_csc(X):
10451045
raise ValueError("X should be in csc format")
10461046

10471047
self.samples = samples

sklearn/tree/_tree.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ cnp.import_array()
3232

3333
from scipy.sparse import issparse
3434
from scipy.sparse import csr_matrix
35+
from scipy.sparse import isspmatrix_csr
3536

3637
from ._utils cimport safe_realloc
3738
from ._utils cimport sizet_ptr_to_ndarray
@@ -876,7 +877,7 @@ cdef class Tree:
876877
"""Finds the terminal region (=leaf node) for each sample in sparse X.
877878
"""
878879
# Check input
879-
if not isinstance(X, csr_matrix):
880+
if not isspmatrix_csr(X):
880881
raise ValueError("X should be in csr_matrix format, got %s"
881882
% type(X))
882883

@@ -1004,7 +1005,7 @@ cdef class Tree:
10041005
"""Finds the decision path (=node) for each sample in X."""
10051006

10061007
# Check input
1007-
if not isinstance(X, csr_matrix):
1008+
if not isspmatrix_csr(X):
10081009
raise ValueError("X should be in csr_matrix format, got %s"
10091010
% type(X))
10101011

sklearn/utils/extmath.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def row_norms(X, squared=False):
7272
The row-wise (squared) Euclidean norm of X.
7373
"""
7474
if sparse.issparse(X):
75-
if not isinstance(X, sparse.csr_matrix):
75+
if not sparse.isspmatrix_csr(X):
7676
X = sparse.csr_matrix(X)
7777
norms = csr_row_norms(X)
7878
else:
@@ -425,7 +425,7 @@ def randomized_svd(
425425
>>> U.shape, s.shape, Vh.shape
426426
((3, 2), (2,), (2, 4))
427427
"""
428-
if isinstance(M, (sparse.lil_matrix, sparse.dok_matrix)):
428+
if sparse.isspmatrix_lil(M) or sparse.isspmatrix_dok(M):
429429
warnings.warn(
430430
"Calculating SVD of a {} is expensive. "
431431
"csr_matrix is more efficient.".format(type(M).__name__),

sklearn/utils/multiclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import warnings
1212

1313
from scipy.sparse import issparse
14-
from scipy.sparse import dok_matrix
15-
from scipy.sparse import lil_matrix
14+
from scipy.sparse import isspmatrix_dok
15+
from scipy.sparse import isspmatrix_lil
1616

1717
import numpy as np
1818

@@ -179,7 +179,7 @@ def is_multilabel(y):
179179
return False
180180

181181
if issparse(y):
182-
if isinstance(y, (dok_matrix, lil_matrix)):
182+
if isspmatrix_dok(y) or isspmatrix_lil(y):
183183
y = y.tocsr()
184184
labels = xp.unique_values(y.data)
185185
return (

sklearn/utils/sparsefuncs.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def mean_variance_axis(X, axis, weights=None, return_sum_weights=False):
103103
"""
104104
_raise_error_wrong_axis(axis)
105105

106-
if isinstance(X, sp.csr_matrix):
106+
if sp.isspmatrix_csr(X):
107107
if axis == 0:
108108
return _csr_mean_var_axis0(
109109
X, weights=weights, return_sum_weights=return_sum_weights
@@ -112,7 +112,7 @@ def mean_variance_axis(X, axis, weights=None, return_sum_weights=False):
112112
return _csc_mean_var_axis0(
113113
X.T, weights=weights, return_sum_weights=return_sum_weights
114114
)
115-
elif isinstance(X, sp.csc_matrix):
115+
elif sp.isspmatrix_csc(X):
116116
if axis == 0:
117117
return _csc_mean_var_axis0(
118118
X, weights=weights, return_sum_weights=return_sum_weights
@@ -187,7 +187,7 @@ def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n, weights=Non
187187
"""
188188
_raise_error_wrong_axis(axis)
189189

190-
if not isinstance(X, (sp.csr_matrix, sp.csc_matrix)):
190+
if not (sp.isspmatrix_csr(X) or sp.isspmatrix_csc(X)):
191191
_raise_typeerror(X)
192192

193193
if np.size(last_n) == 1:
@@ -234,9 +234,9 @@ def inplace_column_scale(X, scale):
234234
scale : ndarray of shape (n_features,), dtype={np.float32, np.float64}
235235
Array of precomputed feature-wise values to use for scaling.
236236
"""
237-
if isinstance(X, sp.csc_matrix):
237+
if sp.isspmatrix_csc(X):
238238
inplace_csr_row_scale(X.T, scale)
239-
elif isinstance(X, sp.csr_matrix):
239+
elif sp.isspmatrix_csr(X):
240240
inplace_csr_column_scale(X, scale)
241241
else:
242242
_raise_typeerror(X)
@@ -256,9 +256,9 @@ def inplace_row_scale(X, scale):
256256
scale : ndarray of shape (n_features,), dtype={np.float32, np.float64}
257257
Array of precomputed sample-wise values to use for scaling.
258258
"""
259-
if isinstance(X, sp.csc_matrix):
259+
if sp.isspmatrix_csc(X):
260260
inplace_csr_column_scale(X.T, scale)
261-
elif isinstance(X, sp.csr_matrix):
261+
elif sp.isspmatrix_csr(X):
262262
inplace_csr_row_scale(X, scale)
263263
else:
264264
_raise_typeerror(X)
@@ -372,9 +372,9 @@ def inplace_swap_row(X, m, n):
372372
n : int
373373
Index of the row of X to be swapped.
374374
"""
375-
if isinstance(X, sp.csc_matrix):
375+
if sp.isspmatrix_csc(X):
376376
inplace_swap_row_csc(X, m, n)
377-
elif isinstance(X, sp.csr_matrix):
377+
elif sp.isspmatrix_csr(X):
378378
inplace_swap_row_csr(X, m, n)
379379
else:
380380
_raise_typeerror(X)
@@ -400,9 +400,9 @@ def inplace_swap_column(X, m, n):
400400
m += X.shape[1]
401401
if n < 0:
402402
n += X.shape[1]
403-
if isinstance(X, sp.csc_matrix):
403+
if sp.isspmatrix_csc(X):
404404
inplace_swap_row_csr(X, m, n)
405-
elif isinstance(X, sp.csr_matrix):
405+
elif sp.isspmatrix_csr(X):
406406
inplace_swap_row_csc(X, m, n)
407407
else:
408408
_raise_typeerror(X)
@@ -501,7 +501,7 @@ def min_max_axis(X, axis, ignore_nan=False):
501501
maxs : ndarray of shape (n_features,), dtype={np.float32, np.float64}
502502
Feature-wise maxima.
503503
"""
504-
if isinstance(X, (sp.csr_matrix, sp.csc_matrix)):
504+
if sp.isspmatrix_csr(X) or sp.isspmatrix_csc(X):
505505
if ignore_nan:
506506
return _sparse_nan_min_max(X, axis=axis)
507507
else:
@@ -610,7 +610,7 @@ def csc_median_axis_0(X):
610610
median : ndarray of shape (n_features,)
611611
Median.
612612
"""
613-
if not isinstance(X, sp.csc_matrix):
613+
if not sp.isspmatrix_csc(X):
614614
raise TypeError("Expected matrix of CSC format, got %s" % X.format)
615615

616616
indptr = X.indptr

0 commit comments

Comments
 (0)
0