8000 [MRG] Update test_metaestimators to pass y parameter when calling sco… · amueller/scikit-learn@55d4433 · GitHub
[go: up one dir, main page]

Skip to content

Commit 55d4433

Browse files
oliverrauschamueller
authored andcommitted
[MRG] Update test_metaestimators to pass y parameter when calling score (scikit-learn#12089)
1 parent ab1136d commit 55d4433

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

sklearn/tests/test_metaestimators.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def __init__(self, name, construct, skip_methods=(),
3737
est, param_distributions={'param': [5]}, cv=2, n_iter=1),
3838
skip_methods=['score']),
3939
DelegatorData('RFE', RFE,
40-
skip_methods=['transform', 'inverse_transform', 'score']),
40+
skip_methods=['transform', 'inverse_transform']),
4141
DelegatorData('RFECV', RFECV,
42-
skip_methods=['transform', 'inverse_transform', 'score']),
42+
skip_methods=['transform', 'inverse_transform']),
4343
DelegatorData('BaggingClassifier', BaggingClassifier,
4444
skip_methods=['transform', 'inverse_transform', 'score',
4545
'predict_proba', 'predict_log_proba',
@@ -101,7 +101,7 @@ def decision_function(self, X, *args, **kwargs):
101101
return np.ones(X.shape[0])
102102

103103
@hides
104-
def score(self, X, *args, **kwargs):
104+
def score(self, X, y, *args, **kwargs):
105105
self._check_fit()
106106
return 1.0
107107

@@ -120,15 +120,24 @@ def score(self, X, *args, **kwargs):
120120
msg="%s does not have method %r when its delegate does"
121121
% (delegator_data.name, method))
122122
# delegation before fit raises a NotFittedError
123-
assert_raises(NotFittedError, getattr(delegator, method),
124-
delegator_data.fit_args[0])
123+
if method == 'score':
124+
assert_raises(NotFittedError, getattr(delegator, method),
125+
delegator_data.fit_args[0],
126+
delegator_data.fit_args[1])
127+
else:
128+
assert_raises(NotFittedError, getattr(delegator, method),
129+
delegator_data.fit_args[0])
125130

126131
delegator.fit(*delegator_data.fit_args)
127132
for method in methods:
128133
if method in delegator_data.skip_methods:
129134
continue
130135
# smoke test delegation
131-
getattr(delegator, method)(delegator_data.fit_args[0])
136+
if method == 'score':
137+
getattr(delegator, method)(delegator_data.fit_args[0],
138+
delegator_data.fit_args[1])
139+
else:
140+
getattr(delegator, method)(delegator_data.fit_args[0])
132141

133142
for method in methods:
134143
if method in delegator_data.skip_methods:

0 commit comments

Comments
 (0)
0