Fix a regression in GridSearchCV for parameter grids that have arrays… · scikit-learn/scikit-learn@f2f4eeb · GitHub
[go: up one dir, main page]

Skip to content

Commit f2f4eeb

Browse files
committed
Fix a regression in GridSearchCV for parameter grids that have arrays of different sizes as parameter values
1 parent a490ab1 commit f2f4eeb

File tree

3 files changed

+67
-29
lines changed

3 files changed

+67
-29
lines changed

doc/whats_new/v1.5.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ Changelog
5858

5959
:mod:`sklearn.utils`
6060
....................
61+
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
62+
grids that have arrays of different sizes as parameter values.
63+
:pr:`29314` by :user:`Marco Gorelli <MarcoGorelli>`.
6164

6265
- |API| :func:`utils.validation.check_array` has a new parameter, `force_writeable`, to
6366
control the writeability of the output array. If set to `True`, the output array will

sklearn/model_selection/_search.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,36 +1086,33 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10861086
for key, param_result in param_results.items():
10871087
param_list = list(param_result.values())
10881088
try:
1089-
with warnings.catch_warnings():
1090-
warnings.filterwarnings(
1091-
"ignore",
1092-
message="in the future the `.dtype` attribute",
1093-
category=DeprecationWarning,
1094-
)
1095-
# Warning raised by NumPy 1.20+
1096-
arr_dtype = np.result_type(*param_list)
1089+
arr = np.array(param_list)
10971090
except (TypeError, ValueError):
10981091
arr_dtype = np.dtype(object)
10991092
else:
1100-
if any(np.min_scalar_type(x) == object for x in param_list):
1101-
# `np.result_type` might get thrown off by `.dtype` properties
1102-
# (which some estimators have).
1103-
# If finding the result dtype this way would give object,
1104-
# then we use object.
1105-
# https://github.com/scikit-learn/scikit-learn/issues/29157
1106-
arr_dtype = np.dtype(object)
1107-
if len(param_list) == n_candidates and arr_dtype != object:
1108-
# Exclude `object` else the numpy constructor might infer a list of
1109-
# tuples to be a 2d array.
1110-
results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1111-
else:
1112-
# Use one MaskedArray and mask all the places where the param is not
1113-
# applicable for that candidate (which may not contain all the params).
1114-
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1115-
for index, value in param_result.items():
1116-
# Setting the value at an index unmasks that index
1117-
ma[index] = value
1118-
results[key] = ma
1093+
arr_dtype = arr.dtype if (arr.dtype.kind != "U") else object
1094+
if len(param_list) == n_candidates:
1095+
try:
1096+
ma = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1097+
except ValueError:
1098+
# Fall back to iterating over `param_result.items()` below
1099+
pass
1100+
else:
1101+
if ma.ndim > 1:
1102+
# If ndim > 1, then a list of tuples might be turned into
1103+
# a 2D array, so we use the fallback below for that case too.
1104+
arr_dtype = object
1105+
else:
1106+
results[key] = ma
1107+
continue
1108+
1109+
# Use one MaskedArray and mask all the places where the param is not
1110+
# applicable for that candidate (which may not contain all the params).
1111+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1112+
for index, value in param_result.items():
1113+
# Setting the value at an index unmasks that index
1114+
ma[index] = value
1115+
results[key] = ma
11191116

11201117
# Store a list of param dicts at the key 'params'
11211118
results["params"] = candidate_params

sklearn/model_selection/tests/test_search.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,13 @@
6565
from sklearn.model_selection.tests.common import OneTimeSplitter
6666
from sklearn.naive_bayes import ComplementNB
6767
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
68-
from sklearn.pipeline import Pipeline
69-
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
68+
from sklearn.pipeline import Pipeline, make_pipeline
69+
from sklearn.preprocessing import (
70+
OneHotEncoder,
71+
OrdinalEncoder,
72+
SplineTransformer,
73+
StandardScaler,
74+
)
7075
from sklearn.svm import SVC, LinearSVC
7176
from sklearn.tests.metadata_routing_common import (
7277
ConsumingScorer,
@@ -2724,6 +2729,39 @@ def test_search_with_estimators_issue_29157():
27242729
assert grid_search.cv_results_["param_enc__enc"].dtype == object
27252730

27262731

2732+
def test_cv_results_multi_size_array_29277():
2733+
x = np.linspace(-np.pi * 2, np.pi * 5, 1000)
2734+
y_true = np.sin(x)
2735+
y_train = y_true[(0 < x) & (x < np.pi * 2)]
2736+
2737+
x_train = x[(0 < x) & (x < np.pi * 2)]
2738+
y_train_noise = y_train + np.random.normal(size=y_train.shape, scale=0.5)
2739+
2740+
x = x.reshape((-1, 1))
2741+
x_train = x_train.reshape((-1, 1))
2742+
2743+
spline_reg_pipe = make_pipeline(
2744+
SplineTransformer(extrapolation="periodic"),
2745+
LinearRegression(fit_intercept=False),
2746+
)
2747+
2748+
spline_reg_pipe_cv = GridSearchCV(
2749+
estimator=spline_reg_pipe,
2750+
param_grid={
2751+
"splinetransformer__knots": [
2752+
np.linspace(0, np.pi * 2, n_knots).reshape((-1, 1))
2753+
for n_knots in range(10, 21, 5)
2754+
],
2755+
},
2756+
verbose=1,
2757+
)
2758+
2759+
spline_reg_pipe_cv.fit(X=x_train, y=y_train_noise)
2760+
assert (
2761+
spline_reg_pipe_cv.cv_results_["param_splinetransformer__knots"].dtype == object
2762+
)
2763+
2764+
27272765
@pytest.mark.parametrize(
27282766
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
27292767
)

0 commit comments

Comments
 (0)
0