8000 mean/std_test_time --> mean/std_time · scikit-learn/scikit-learn@00d3bfd · GitHub
[go: up one dir, main page]

Skip to content

Commit 00d3bfd

Browse files
committed
mean/std_test_time --> mean/std_time
1 parent 96a92f5 commit 00d3bfd

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

sklearn/model_selection/_search.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -590,11 +590,14 @@ def _fit(self, X, y, labels, parameter_iterable):
590590

591591
time = np.array(time, dtype=np.float64).reshape(n_candidates, n_splits)
592592
time_means = np.average(time, axis=1)
593-
time_stds = np.sqrt(
594-
np.average((time - time_means[:, np.newaxis]) ** 2,
595-
axis=1))
593+
time_stds = np.sqrt(np.average((time - time_means[:, np.newaxis]) ** 2,
594+
axis=1))
596595

597596
cv_results = dict()
597+
598+
cv_results["mean_time"] = time_means
599+
cv_results["std_time"] = time_stds
600+
598601
for split_i in range(n_splits):
599602
cv_results["split%d_test_score" % split_i] = test_scores[:,
600603
split_i]
@@ -615,13 +618,12 @@ def _fit(self, X, y, labels, parameter_iterable):
615618
method='min'),
616619
dtype=np.int32)
617620

618-
cv_results["mean_test_time"] = time_means
619-
cv_results["std_test_time"] = time_stds
620-
ranks = np.asarray(rankdata(-test_means, method='min'), dtype=np.int32)
621+
cv_results["rank_test_score"] = np.asarray(rankdata(-test_means,
622+
method='min'),
623+
dtype=np.int32)
621624

622-
best_index = np.flatnonzero(ranks == 1)[0]
625+
best_index = np.flatnonzero(cv_results["rank_test_score"] == 1)[0]
623626
best_parameters = candidate_params[best_index]
624-
cv_results["rank_test_score"] = ranks
625627

626628
# Use one np.MaskedArray and mask all the places where the param is not
627629
# applicable for that candidate. Use defaultdict as each candidate may

sklearn/model_selection/tests/test_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def test_grid_search_cv_results():
657657
'split0_train_score', 'split1_train_score',
658658
'split2_train_score',
659659
'std_test_score', 'std_train_score',
660-
'mean_test_time', 'std_test_time')
660+
'mean_time', 'std_time')
661661
n_candidates = n_grid_points
662662

663663
for search, iid in zip((grid_search, grid_search_iid), (False, True)):
@@ -720,7 +720,7 @@ def test_random_search_cv_results():
720720
'split0_train_score', 'split1_train_score',
721721
'split2_train_score',
722722
'std_test_score', 'std_train_score',
723-
'mean_test_time', 'std_test_time')
723+
'mean_time', 'std_time')
724724
n_cand = n_search_iter
725725

726726
for search, iid in zip((random_search, random_search_iid), (False, True)):

0 commit comments

Comments
 (0)
0