8000 ENH Adds support for pandas dataframe with only sparse arrays (#16728) · scikit-learn/scikit-learn@fa1ea2a · GitHub
[go: up one dir, main page]

Skip to content

Commit fa1ea2a

Browse files
authored
ENH Adds support for pandas dataframe with only sparse arrays (#16728)
1 parent fb2b01f commit fa1ea2a

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

doc/whats_new/v0.23.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,10 @@ Changelog
413413
pandas sparse DataFrame.
414414
:pr:`16021` by :user:`Rushabh Vasani <rushabh-v>`.
415415

416+
- |Enhancement| :func:`utils.validation.check_array` now constructs a sparse
417+
matrix from a pandas DataFrame that contains only `SparseArray`s.
418+
:pr:`16728` by `Thomas Fan`_.
419+
416420
:mod:`sklearn.cluster`
417421
......................
418422

sklearn/linear_model/tests/test_base.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,16 +212,30 @@ def test_linear_regression_pd_sparse_dataframe_warning():
212212
# restrict the pd versions < '0.24.0' as they have a bug in is_sparse func
213213
if LooseVersion(pd.__version__) < '0.24.0':
214214
pytest.skip("pandas 0.24+ required.")
215-
df = pd.DataFrame()
216-
for col in range(4):
215+
216+
# Warning is raised only when some of the columns is sparse
217+
df = pd.DataFrame({'0': np.random.randn(10)})
218+
for col in range(1, 4):
217219
arr = np.random.randn(10)
218220
arr[:8] = 0
219-
df[str(col)] = pd.arrays.SparseArray(arr, fill_value=0)
221+
# all columns but the first column is sparse
222+
if col != 0:
223+
arr = pd.arrays.SparseArray(arr, fill_value=0)
224+
df[str(col)] = arr
225+
220226
msg = "pandas.DataFrame with sparse columns found."
221227
with pytest.warns(UserWarning, match=msg):
222228
reg = LinearRegression()
223229
reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
224230

231+
# does not warn when the whole dataframe is sparse
232+
df['0'] = pd.arrays.SparseArray(df['0'], fill_value=0)
233+
assert hasattr(df, "sparse")
234+
235+
with pytest.warns(None) as record:
236+
reg.fit(df.iloc[:, 0:2], df.iloc[:, 3])
237+
assert not record
238+
225239

226240
def test_preprocess_data():
227241
n_samples = 200

sklearn/utils/tests/test_validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,3 +1153,22 @@ def test_check_fit_params(indices):
11531153
result['sparse-col'],
11541154
_safe_indexing(fit_params['sparse-col'], indices_)
11551155
)
1156+
1157+
1158+
@pytest.mark.parametrize('sp_format', [True, 'csr', 'csc', 'coo', 'bsr'])
1159+
def test_check_sparse_pandas_sp_format(sp_format):
1160+
# check_array converts pandas dataframe with only sparse arrays into
1161+
# sparse matrix
1162+
pd = pytest.importorskip("pandas")
1163+
sp_mat = _sparse_random_matrix(10, 3)
1164+
1165+
sdf = pd.DataFrame.sparse.from_spmatrix(sp_mat)
1166+
result = check_array(sdf, accept_sparse=sp_format)
1167+
1168+
if sp_format is True:
1169+
# by default pandas converts to coo when accept_sparse is True
1170+
sp_format = 'coo'
1171+
1172+
assert sp.issparse(result)
1173+
assert result.format == sp_format
1174+
assert_allclose_dense_sparse(sp_mat, result)

sklearn/utils/validation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
451451
# DataFrame), and store them. If not, store None.
452452
dtypes_orig = None
453453
if hasattr(array, "dtypes") and hasattr(array.dtypes, '__array__'):
454-
# throw warning if pandas dataframe is sparse
454+
# throw warning if columns are sparse. If all columns are sparse, then
455+
# array.sparse exists and sparsity will be perserved (later).
455456
with suppress(ImportError):
456457
from pandas.api.types import is_sparse
457-
if array.dtypes.apply(is_sparse).any():
458+
if (not hasattr(array, 'sparse') and
459+
array.dtypes.apply(is_sparse).any()):
458460
warnings.warn(
459461
"pandas.DataFrame with sparse columns found."
460462
"It will be converted to a dense numpy array."
@@ -498,6 +500,11 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
498500
estimator_name = "Estimator"
499501
context = " by %s" % estimator_name if estimator is not None else ""
500502

503+
# When all dataframe columns are sparse, convert to a sparse array
504+
if hasattr(array, 'sparse') and array.ndim > 1:
505+
# DataFrame.sparse only supports `to_coo`
506+
array = array.sparse.to_coo()
507+
501508
if sp.issparse(array):
502509
_ensure_no_complex_data(array)
503510
array = _ensure_sparse_format(array, accept_sparse=accept_sparse,

0 commit comments

Comments
 (0)
0