8000 Improve checks and fix first assertion in test_grid_search_bad_param_… · scikit-learn/scikit-learn@0f1e162 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0f1e162

Browse files
committed
Improve checks and fix first assertion in test_grid_search_bad_param_grid
1 parent 18722f8 commit 0f1e162

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

sklearn/model_selection/_search.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from abc import ABCMeta, abstractmethod
1414
from collections import defaultdict 8000
15-
from collections.abc import Mapping, Iterable
15+
from collections.abc import Mapping, Sequence, Iterable
1616
from functools import partial, reduce
1717
from itertools import product
1818
import numbers
@@ -94,7 +94,8 @@ class ParameterGrid:
9494
def __init__(self, param_grid):
9595
if not isinstance(param_grid, (Mapping, Iterable)):
9696
raise TypeError(
97-
"Parameter grid is not a dict or a list ({!r})".format(param_grid)
97+
f"Parameter grid should be a dict or a list, got: {param_grid!r} of"
98+
f" type {type(param_grid)}"
9899
)
99100

100101
if isinstance(param_grid, Mapping):
@@ -108,11 +109,23 @@ def __init__(self, param_grid):
108109
raise TypeError("Parameter grid is not a dict ({!r})".format(grid))
109110
for key, value in grid.items():
110111
if isinstance(value, np.ndarray) and value.ndim > 1:
111-
raise ValueError("Parameter array should be one-dimensional.")
112-
if not isinstance(grid[key], Iterable):
112+
raise ValueError(
113+
f"Parameter array for {key} should be one-dimensional, got:"
114+
f" {value!r} with shape {value.shape}"
115+
)
116+
if isinstance(value, str) or not isinstance(
117+
value, (np.ndarray, Sequence)
118+
):
113119
raise TypeError(
114-
"Parameter grid value is not iterable "
115-
"(key={!r}, value={!r})".format(key, grid[key])
120+
f"Parameter grid for parameter {key!r} needs to be a list or a"
121+
f" numpy array, but got {value!r} (of type {type(value)})"
122+
" instead. Single values need to be wrapped in a list with one"
123+
" element."
124+
)
125+
if len(value) == 0:
126+
raise ValueError(
127+
f"Parameter grid for parameter {key!r} need "
128+
f"to be a non-empty sequence, got: {value!r}"
116129
)
117130

118131
self.param_grid = param_grid

sklearn/model_selection/tests/test_search.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,13 +444,12 @@ def test_grid_search_bad_param_grid():
444444
param_dict = {"C": 1}
445445
clf = SVC(gamma="auto")
446446
error_msg = re.escape(
447-
"Parameter grid for parameter (C) needs to"
448-
" be a list or numpy array, but got (<class 'int'>)."
449-
" Single values need to be wrapped in a list"
447+
"Parameter grid for parameter 'C' needs to be a list or a numpy array, but got"
448+
" 1 (of type <class 'int'>) instead. Single values need to be wrapped in a list"
450449
" with one element."
451450
)
452451
search = GridSearchCV(clf, param_dict)
453-
with pytest.raises(ValueError, match=error_msg):
452+
with pytest.raises(TypeError, match=error_msg):
454453
search.fit(X, y)
455454

456455
param_dict = {"C": []}

0 commit comments

Comments
 (0)
0