8000 Fix tests; Unsupervised multimetric gs will not pass until #8117 is m… · raghavrv/scikit-learn@636086d · GitHub
[go: up one dir, main page]

Skip to content

Commit 636086d

Browse files
committed
Fix tests; Unsupervised multimetric gs will not pass until scikit-learn#8117 is merged
1 parent 00dbd01 commit 636086d

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

sklearn/model_selection/tests/test_search.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,10 +493,12 @@ def test_X_as_list():
493493
cv = KFold(n_splits=3)
494494

495495
for scoring in (None, 'accuracy', ('accuracy', ),
496-
('accuracy', 'precision')):
496+
('accuracy', 'recall')):
497497
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv,
498-
scoring=scoring)
499-
grid_search.fit(X.tolist(), y).score(X, y)
498+
scoring=scoring,
499+
refit='accuracy'
500+
if scoring and len(scoring) > 0 else True)
501+
grid_search.fit(X.tolist(), y).score(X.tolist(), y)
500502
assert_true(hasattr(grid_search, "cv_results_"))
501503

502504

@@ -509,10 +511,12 @@ def test_y_as_list():
509511
cv = KFold(n_splits=3)
510512

511513
for scoring in (None, 'accuracy', ('accuracy', ),
512-
('accuracy', 'precision')):
514+
('accuracy', 'recall')):
513515
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv,
514-
scoring=scoring)
515-
grid_search.fit(X, y.tolist()).score(X, y)
516+
scoring=scoring,
517+
refit='accuracy'
518+
if scoring and len(scoring) > 0 else True)
519+
grid_search.fit(X, y.tolist()).score(X, y.tolist())
516520
assert_true(hasattr(grid_search, "cv_results_"))
517521

518522

0 commit comments

Comments
 (0)
0