FIX SkLearn `.score()` method generating error with Dask DataFrames (… · lithuak/scikit-learn@5d8dfc9

Commit 5d8dfc9

ZWMiller authored and jnothman committed
FIX SkLearn .score() method generating error with Dask DataFrames (scikit-learn#12462)
1 parent 6b4e00d commit 5d8dfc9

3 files changed: +22 -1 lines changed

doc/whats_new/v0.20.rst

Lines changed: 3 additions & 0 deletions

@@ -1314,6 +1314,9 @@ Miscellaneous
   happens immediately (i.e., without a deprecation cycle).
   :issue:`11741` by `Olivier Grisel`_.
 
+- |Fix| Fixed a bug in validation helpers where passing a Dask DataFrame results
+  in an error. :issue:`12462` by :user:`Zachariah Miller <zwmiller>`
+
 Changes to estimator checks
 ---------------------------
 

sklearn/utils/tests/test_validation.py

Lines changed: 13 additions & 0 deletions

@@ -41,6 +41,7 @@
     check_memory,
     check_non_negative,
     LARGE_SPARSE_SUPPORTED,
+    _num_samples
 )
 import sklearn
 
@@ -786,3 +787,15 @@ def test_check_X_y_informative_error():
     X = np.ones((2, 2))
     y = None
     assert_raise_message(ValueError, "y cannot be None", check_X_y, X, y)
+
+
+def test_retrieve_samples_from_non_standard_shape():
+    class TestNonNumericShape:
+        def __init__(self):
+            self.shape = ("not numeric",)
+
+        def __len__(self):
+            return len([1, 2, 3])
+
+    X = TestNonNumericShape()
+    assert _num_samples(X) == len(X)
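
To illustrate what the new test guards against, the sketch below runs a similar object through the pre-patch branch, which returned `x.shape[0]` unconditionally (the line this commit removes). `old_num_samples` and `NonNumericShape` are hypothetical names used only for illustration; this is a reduced stand-in, not the full scikit-learn helper.

```python
def old_num_samples(x):
    """Pre-patch behaviour, reduced to the branch this commit touches."""
    if hasattr(x, "shape"):
        if len(x.shape) == 0:
            raise TypeError("Singleton array %r cannot be considered"
                            " a valid collection." % x)
        # Returned as-is, even when it is not an integer.
        return x.shape[0]
    return len(x)


class NonNumericShape:
    """Hypothetical container mirroring the test helper: shape[0] is a string."""

    def __init__(self):
        self.shape = ("not numeric",)

    def __len__(self):
        return 3


X = NonNumericShape()
result = old_num_samples(X)
print(repr(result))      # 'not numeric'
print(result == len(X))  # False
```

With the old branch, the caller receives the string rather than a sample count, so the assertion `_num_samples(X) == len(X)` from the test would fail.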

sklearn/utils/validation.py

Lines changed: 6 additions & 1 deletion

@@ -140,7 +140,12 @@ def _num_samples(x):
         if len(x.shape) == 0:
             raise TypeError("Singleton array %r cannot be considered"
                             " a valid collection." % x)
-        return x.shape[0]
+        # Check that shape is returning an integer or default to len
+        # Dask dataframes may not return numeric shape[0] value
+        if isinstance(x.shape[0], numbers.Integral):
+            return x.shape[0]
+        else:
+            return len(x)
     else:
         return len(x)
 
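
The patched branch can be tried in isolation with the minimal sketch below. It assumes only what the hunk shows: `shape[0]` is trusted when it is a `numbers.Integral`, otherwise `len(x)` is used. `num_samples_sketch` and `FakeLazyFrame` are hypothetical names, and the fake frame merely imitates a container whose row count is a non-integer placeholder; it is not Dask's API.

```python
import numbers


def num_samples_sketch(x):
    """Reduced stand-in for the patched branch of `_num_samples`."""
    if hasattr(x, "shape"):
        if len(x.shape) == 0:
            raise TypeError("Singleton array %r cannot be considered"
                            " a valid collection." % x)
        # Trust shape[0] only when it is a genuine integer.
        if isinstance(x.shape[0], numbers.Integral):
            return x.shape[0]
        return len(x)
    return len(x)


class FakeLazyFrame:
    """Hypothetical frame whose shape[0] is a placeholder, not an int."""

    def __init__(self, rows):
        self._rows = rows
        self.shape = (object(), 2)  # row count unknown until computed

    def __len__(self):
        return len(self._rows)


class FakeArray:
    """Hypothetical eager container with an ordinary integer shape."""

    shape = (4, 2)

    def __len__(self):
        return 4


lazy = FakeLazyFrame([[1, 2], [3, 4], [5, 6]])
print(num_samples_sketch(lazy))         # 3, via the len() fallback
print(num_samples_sketch(FakeArray()))  # 4, via shape[0]
```

Checking against `numbers.Integral` rather than `int` is the looser of the two tests: it accepts any type registered as an integral number while still rejecting floats, strings, and placeholder objects.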
