8000 MNT: Use Polars in test_get_column_indices_interchange (#31095) · scikit-learn/scikit-learn@e146da1 · GitHub
[go: up one dir, main page]

Skip to content

Commit e146da1

Browse files
authored
MNT: Use Polars in test_get_column_indices_interchange (#31095)
1 parent 7505ed5 commit e146da1

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

sklearn/utils/tests/test_indexing.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,10 @@ def test_get_column_indices_pandas_nonunique_columns_error(key):
449449

450450
def test_get_column_indices_interchange():
451451
"""Check _get_column_indices for edge cases with the interchange"""
452-
pd = pytest.importorskip("pandas", minversion="1.5")
452+
pl = pytest.importorskip("polars")
453453

454-
df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"])
455-
456-
# Hide the fact that this is a pandas dataframe to trigger the dataframe protocol
457-
# code path.
458-
class MockDataFrame:
459-
def __init__(self, df):
460-
self._df = df
461-
462-
def __getattr__(self, name):
463-
return getattr(self._df, name)
464-
465-
df_mocked = MockDataFrame(df)
454+
# Polars dataframes go down the interchange path.
455+
df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"])
466456

467457
key_results = [
468458
(slice(1, None), [1, 2]),
@@ -476,15 +466,15 @@ def __getattr__(self, name):
476466
([], []),
477467
]
478468
for key, result in key_results:
479-
assert _get_column_indices(df_mocked, key) == result
469+
assert _get_column_indices(df, key) == result
480470

481471
msg = "A given column is not a column of the dataframe"
482472
with pytest.raises(ValueError, match=msg):
483-
_get_column_indices(df_mocked, ["not_a_column"])
473+
_get_column_indices(df, ["not_a_column"])
484474

485475
msg = "key.step must be 1 or None"
486476
with pytest.raises(NotImplementedError, match=msg):
487-
_get_column_indices(df_mocked, slice("a", None, 2))
477+
_get_column_indices(df, slice("a", None, 2))
488478

489479

490480
def test_resample():

0 commit comments

Comments
 (0)
0