10000 FIX `param_distribution` param of `HalvingRandomSearchCV` accepts li… · REDVM/scikit-learn@6400cb1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6400cb1

Browse files
StefanieSenger authored and glemaitre committed
FIX param_distribution param of HalvingRandomSearchCV accepts list of dicts (scikit-learn#26893)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 3781e6c commit 6400cb1

File tree

4 files changed

+99
-25
lines changed

4 files changed

+99
-25
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ Changelog
2323
:attr:`sklearn.neighbors.KDTree.valid_metrics` as public class attributes.
2424
:pr:`26754` by :user:`Julien Jerphanion <jjerphan>`.
2525

26+
- |Fix| :class:`sklearn.model_selection.HalvingRandomSearchCV` no longer raises
27+
when the input to the `param_distributions` parameter is a list of dicts.
28+
:pr:`26893` by :user:`Stefanie Senger <StefanieSenger>`.
29+
2630
:mod:`sklearn.preprocessing`
2731
............................
2832

sklearn/model_selection/_search_successive_halving.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -750,11 +750,13 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
750750
Either estimator needs to provide a ``score`` function,
751751
or ``scoring`` must be passed.
752752
753-
param_distributions : dict
754-
Dictionary with parameters names (string) as keys and distributions
753+
param_distributions : dict or list of dicts
754+
Dictionary with parameters names (`str`) as keys and distributions
755755
or lists of parameters to try. Distributions must provide a ``rvs``
756756
method for sampling (such as those from scipy.stats.distributions).
757757
If a list is given, it is sampled uniformly.
758+
If a list of dicts is given, first a dict is sampled uniformly, and
759+
then a parameter is sampled using that dict as above.
758760
759761
n_candidates : "exhaust" or int, default="exhaust"
760762
The number of candidate parameters to sample, at the first
@@ -1024,7 +1026,7 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
10241026

10251027
_parameter_constraints: dict = {
10261028
**BaseSuccessiveHalving._parameter_constraints,
1027-
"param_distributions": [dict],
1029+
"param_distributions": [dict, list],
10281030
"n_candidates": [
10291031
Interval(Integral, 0, None, closed="neither"),
10301032
StrOptions({"exhaust"}),

sklearn/model_selection/tests/test_search.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -900,18 +900,16 @@ def check_cv_results_array_types(search, param_keys, score_keys):
900900
assert cv_results["rank_test_%s" % key].dtype == np.int32
901901

902902

903-
def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand):
903+
def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand, extra_keys=()):
904904
# Test the search.cv_results_ contains all the required results
905-
assert_array_equal(
906-
sorted(cv_results.keys()), sorted(param_keys + score_keys + ("params",))
907-
)
905+
all_keys = param_keys + score_keys + extra_keys
906+
assert_array_equal(sorted(cv_results.keys()), sorted(all_keys + ("params",)))
908907
assert all(cv_results[key].shape == (n_cand,) for key in param_keys + score_keys)
909908

910909

911910
def test_grid_search_cv_results():
912911
X, y = make_classification(n_samples=50, n_features=4, random_state=42)
913912

914-
n_splits = 3
915913
n_grid_points = 6
916914
params = [
917915
dict(
@@ -949,9 +947,7 @@ def test_grid_search_cv_results():
949947
)
950948
n_candidates = n_grid_points
951949

952-
search = GridSearchCV(
953-
SVC(), cv=n_splits, param_grid=params, return_train_score=True
954-
)
950+
search = GridSearchCV(SVC(), cv=3, param_grid=params, return_train_score=True)
955951
search.fit(X, y)
956952
cv_results = search.cv_results_
957953
# Check if score and timing are reasonable
@@ -967,31 +963,35 @@ def test_grid_search_cv_results():
967963
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
968964
# Check masking
969965
cv_results = search.cv_results_
970-
n_candidates = len(search.cv_results_["params"])
971-
assert all(
966+
967+
poly_results = [
972968
(
973969
cv_results["param_C"].mask[i]
974970
and cv_results["param_gamma"].mask[i]
975971
and not cv_results["param_degree"].mask[i]
976972
)
977973
for i in range(n_candidates)
978-
if cv_results["param_kernel"][i] == "linear"
979-
)
980-
assert all(
974+
if cv_results["param_kernel"][i] == "poly"
975+
]
976+
assert all(poly_results)
977+
assert len(poly_results) == 2
978+
979+
rbf_results = [
981980
(
982981
not cv_results["param_C"].mask[i]
983982
and not cv_results["param_gamma"].mask[i]
984983
and cv_results["param_degree"].mask[i]
985984
)
986985
for i in range(n_candidates)
987986
if cv_results["param_kernel"][i] == "rbf"
988-
)
987+
]
988+
assert all(rbf_results)
989+
assert len(rbf_results) == 4
989990

990991

991992
def test_random_search_cv_results():
992993
X, y = make_classification(n_samples=50, n_features=4, random_state=42)
993994

994-
n_splits = 3
995995
n_search_iter = 30
996996

997997
params = [
@@ -1016,29 +1016,28 @@ def test_random_search_cv_results():
10161016
"mean_score_time",
10171017
"std_score_time",
10181018
)
1019-
n_cand = n_search_iter
1019+
n_candidates = n_search_iter
10201020

10211021
search = RandomizedSearchCV(
10221022
SVC(),
10231023
n_iter=n_search_iter,
1024-
cv=n_splits,
1024+
cv=3,
10251025
param_distributions=params,
10261026
return_train_score=True,
10271027
)
10281028
search.fit(X, y)
10291029
cv_results = search.cv_results_
10301030
# Check results structure
10311031
check_cv_results_array_types(search, param_keys, score_keys)
1032-
check_cv_results_keys(cv_results, param_keys, score_keys, n_cand)
1033-
n_candidates = len(search.cv_results_["params"])
1032+
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
10341033
assert all(
10351034
(
10361035
cv_results["param_C"].mask[i]
10371036
and cv_results["param_gamma"].mask[i]
10381037
and not cv_results["param_degree"].mask[i]
10391038
)
10401039
for i in range(n_candidates)
1041-
if cv_results["param_kernel"][i] == "linear"
1040+
if cv_results["param_kernel"][i] == "poly"
10421041
)
10431042
assert all(
10441043
(

sklearn/model_selection/tests/test_successive_halving.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import pytest
5-
from scipy.stats import norm, randint
5+
from scipy.stats import expon, norm, randint
66

77
from sklearn.datasets import make_classification
88
from sklearn.dummy import DummyClassifier
@@ -23,7 +23,11 @@
2323
_SubsampleMetaSplitter,
2424
_top_k,
2525
)
26-
from sklearn.svm import LinearSVC
26+
from sklearn.model_selection.tests.test_search import (
27+
check_cv_results_array_types,
28+
check_cv_results_keys,
29+
)
30+
from sklearn.svm import SVC, LinearSVC
2731

2832

2933
class FastClassifier(DummyClassifier):
@@ -777,3 +781,68 @@ def test_select_best_index(SearchCV):
777781
# we expect the index of 'i'
778782
best_index = SearchCV._select_best_index(None, None, results)
779783
assert best_index == 8
784+
785+
786+
def test_halving_random_search_list_of_dicts():
787+
"""Check the behaviour of `HalvingRandomSearchCV` when `param_distributions`
788+
is a list of dictionaries.
789+
"""
790+
X, y = make_classification(n_samples=150, n_features=4, random_state=42)
791+
792+
params = [
793+
{"kernel": ["rbf"], "C": expon(scale=10), "gamma": expon(scale=0.1)},
794+
{"kernel": ["poly"], "degree": [2, 3]},
795+
]
796+
param_keys = (
797+
"param_C",
798+
"param_degree",
799+
"param_gamma",
800+
"param_kernel",
801+
)
802+
score_keys = (
803+
"mean_test_score",
804+
"mean_train_score",
805+
"rank_test_score",
806+
"split0_test_score",
807+
"split1_test_score",
808+
"split2_test_score",
809+
"split0_train_score",
810+
"split1_train_score",
811+
"split2_train_score",
812+
"std_test_score",
813+
"std_train_score",
814+
"mean_fit_time",
815+
"std_fit_time",
816+
"mean_score_time",
817+
"std_score_time",
818+
)
819+
extra_keys = ("n_resources", "iter")
820+
821+
search = HalvingRandomSearchCV(
822+
SVC(), cv=3, param_distributions=params, return_train_score=True, random_state=0
823+
)
824+
search.fit(X, y)
825+
n_candidates = sum(search.n_candidates_)
826+
cv_results = search.cv_results_
827+
# Check results structure
828+
check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates, extra_keys)
829+
check_cv_results_array_types(search, param_keys, score_keys)
830+
831+
assert all(
832+
(
833+
cv_results["param_C"].mask[i]
834+
and cv_results["param_gamma"].mask[i]
835+
and not cv_results["param_degree"].mask[i]
836+
)
837+
for i in range(n_candidates)
838+
if cv_results["param_kernel"][i] == "poly"
839+
)
840+
assert all(
841+
(
842+
not cv_results["param_C"].mask[i]
843+
and not cv_results["param_gamma"].mask[i]
844+
and cv_results["param_degree"].mask[i]
845+
)
846+
for i in range(n_candidates)
847+
if cv_results["param_kernel"][i] == "rbf"
848+
)

0 commit comments

Comments (0)