8000 [MRG+2] Fixed assumption fit attribute means object is estimator. (#8… · NelleV/scikit-learn@9db2e5f · GitHub
[go: up one dir, main page]

Skip to content

Commit 9db2e5f

Browse files
drkatnzNelleV
authored andcommitted
[MRG+2] Fixed assumption fit attribute means object is estimator. (scikit-learn#8418)
1 parent b31367d commit 9db2e5f

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.utils.testing import assert_warns_message
1616
from sklearn.utils.testing import assert_warns
1717
from sklearn.utils.testing import ignore_warnings
18+
from sklearn.utils.testing import SkipTest
1819
from sklearn.utils import as_float_array, check_array, check_symmetric
1920
from sklearn.utils import check_X_y
2021
from sklearn.utils.mocking import MockDataFrame
@@ -501,3 +502,15 @@ def test_check_consistent_length():
501502
assert_raises_regexp(TypeError, 'estimator', check_consistent_length,
502503
[1, 2], RandomForestRegressor())
503504
# XXX: We should have a test with a string, but what is correct behaviour?
505+
506+
507+
def test_check_dataframe_fit_attribute():
508+
# check pandas dataframe with 'fit' column does not raise error
509+
# https://github.com/scikit-learn/scikit-learn/issues/8415 8000
510+
try:
511+
import pandas as pd
512+
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
513+
X_df = pd.DataFrame(X, columns=['a', 'b', 'fit'])
514+
check_consistent_length(X_df)
515+
except ImportError:
516+
raise SkipTest("Pandas not found")

sklearn/utils/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _is_arraylike(x):
9292

9393
def _num_samples(x):
9494
"""Return number of samples in array-like x."""
95-
if hasattr(x, 'fit'):
95+
if hasattr(x, 'fit') and callable(x.fit):
9696
# Don't get num_samples from an ensembles length!
9797
raise TypeError('Expected sequence or array-like, got '
9898
'estimator %s' % x)

0 commit comments

Comments
 (0)
0