8000 Revert "avoid branching, handle higher-dim case" · scikit-learn/scikit-learn@06e35b6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 06e35b6

Browse files
committed
Revert "avoid branching, handle higher-dim case"
This reverts commit 44af268.
1 parent 44af268 commit 06e35b6

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

sklearn/model_selection/_search.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,15 +1079,16 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10791079
param_results["param_%s" % name][cand_idx] = value
10801080
for key in param_results:
10811081
arr = np.array(list(param_results[key].values()))
1082-
# Use one MaskedArray and mask all the places where the param is not
1083-
# applicable for that candidate (which may not contain all the params).
1084-
ma = MaskedArray(
1085-
np.empty((n_candidates, *arr.shape[1:])), mask=True, dtype=arr.dtype
1086-
)
1087-
for index, value in param_results[key].items():
1088-
# Setting the value at an index unmasks that index
1089-
ma[index] = value
1090-
param_results[key] = ma
1082+
if len(arr) == n_candidates:
1083+
param_results[key] = MaskedArray(arr, mask=False)
1084+
else:
1085+
# Use one MaskedArray and mask all the places where the param is not
1086+
# applicable for that candidate (which may not contain all the params).
1087+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr.dtype)
1088+
for index, value in param_results[key].items():
1089+
# Setting the value at an index unmasks that index
1090+
ma[index] = value
1091+
param_results[key] = ma
10911092

10921093
results.update(param_results)
10931094
# Store a list of param dicts at the key 'params'

0 commit comments

Comments
 (0)
0