8000 Add a few more lines to test_grid_search_results(): · scikit-learn/scikit-learn@5d998f2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5d998f2

Browse files
Eugene Chenraghavrv
Eugene Chen
authored andcommitted
Add a few more lines to test_grid_search_results():
1. check test_rank_score always >= 1 2. check all regular scores (test/train_mean/std_score) and timing >= 0 3. check all regular scores <= 1 Note that timing can be greater than 1 in general, and std of regular scores always <= 1 because the scores are bounded between 0 and 1.
1 parent b6a2618 commit 5d998f2

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

sklearn/model_selection/tests/test_search.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,13 @@ def test_grid_search_results():
659659
for search, iid in zip((grid_search, grid_search_iid), (False, True)):
660660
assert_equal(iid, search.iid)
661661
results = search.cv_results_
662+
# Check if score and timing are reasonable
663+
assert_true(all(results['test_rank_test_score'] >= 1))
664+
assert_true(all(results[k] >= 0) for k in score_keys
665+
if k is not 'rank_test_score')
666+
assert_true(all(results[k] <= 1) for k in score_keys
667+
if not k.endswith('time') and
668+
k is not 'rank_test_score')
662669
# Check results structure
663670
check_cv_results_array_types(results, param_keys, score_keys)
664671
check_cv_results_keys(results, param_keys, score_keys, n_candidates)

0 commit comments

Comments
 (0)
0