Closed
Labels
Description
Describe the bug
GridSearchCV fails at the last step, in _format_results(), if a parameter value is of a complex type, such as a dict whose values mix a string and a number.
The code below works for
param_grid=dict(special_param=[{"key1": 1.5, "key2": 18}, None])
but fails for
param_grid=dict(special_param=[{"key1": "some_string", "key2": 18}, None])
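To see how the two grids differ before scikit-learn is involved, the values can be passed straight to `np.result_type`, the call named in the traceback under Actual Results. This is a minimal sketch assuming only NumPy; `probe_result_type` is a hypothetical helper for illustration, and the exact exception types may depend on the NumPy version.

```python
import numpy as np


def probe_result_type(values):
    """Return the dtype np.result_type infers, or the exception it raises."""
    try:
        return np.result_type(*values)
    except Exception as exc:  # broad on purpose: we only want to inspect it
        return exc


working_grid = [{"key1": 1.5, "key2": 18}, None]
failing_grid = [{"key1": "some_string", "key2": 18}, None]

for label, grid in [("working", working_grid), ("failing", failing_grid)]:
    outcome = probe_result_type(grid)
    # Print the outcome's type name so the two grids can be compared.
    print(label, type(outcome).__name__, outcome)
```

If the two grids raise different exception types, that would be consistent with only one of them reaching the user, since a caller that catches a single exception type lets the other propagate.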
Steps/Code to Reproduce
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV


class MyLogReg(LogisticRegression):
    def __init__(self, special_param=None, C=1):
        super().__init__(C=C)
        self.special_param = special_param


X, y = make_classification(n_samples=100, random_state=100)
classifier = MyLogReg(C=1)
gs = GridSearchCV(estimator=classifier, cv=3, scoring="f1",
                  param_grid=dict(special_param=[{"key1": "some_string", "key2": 18}, None]),
                  verbose=2)
gs.fit(X, y)
print(gs.cv_results_["params"])
Expected Results
Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
[{'special_param': {'key1': 'some_string', 'key2': 18}}, {'special_param': None}]
Actual Results
Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
[CV] END .................................special_param=None; total time= 0.0s
Traceback (most recent call last):
  File "~/.config/JetBrains/PyCharm2024.1/scratches/scratch_1.py", line 18, in <module>
    gs.fit(X, y)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 968, in fit
    self._run_search(evaluate_candidates)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 1544, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 962, in evaluate_candidates
    results = self._format_results(
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 1092, in _format_results
    arr_dtype = np.result_type(*param_list)
  File "~/my_project/.venv/lib/python3.10/site-packages/numpy/core/_internal.py", line 62, in _usefields
    names, formats, offsets, titles = _makenames_list(adict, align)
  File "~/my_project/.venv/lib/python3.10/site-packages/numpy/core/_internal.py", line 32, in _makenames_list
    raise ValueError("entry not a 2- or 3- tuple")
ValueError: entry not a 2- or 3- tuple
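The last scikit-learn frame in the traceback is `arr_dtype = np.result_type(*param_list)`, and the `ValueError` comes from NumPy trying to read the dict as a structured dtype specification. One possible defensive pattern, sketched here under the assumption that falling back to an object dtype is acceptable for such parameters, is to catch both `TypeError` and `ValueError` around that call. `infer_param_dtype` is a hypothetical helper, not scikit-learn's actual code or fix.

```python
import numpy as np


def infer_param_dtype(param_list):
    """Infer a common dtype for parameter values, or fall back to object."""
    try:
        return np.result_type(*param_list)
    except (TypeError, ValueError):
        # Dicts and other arbitrary objects are not valid dtype inputs,
        # so store them as generic Python objects instead.
        return np.dtype(object)


print(infer_param_dtype([{"key1": "some_string", "key2": 18}, None]))
print(infer_param_dtype([1, 2.5]))
```

Purely numeric parameter lists keep their inferred numeric dtype; only values NumPy cannot interpret take the object fallback.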
Versions
import sklearn; sklearn.show_versions()
System:
    python: 3.10.4 (main, Sep 29 2022, 14:13:19) [GCC 11.2.0]
    executable: ~/my_project/.venv/bin/python
    machine: Linux-5.15.0-107-generic-x86_64-with-glibc2.35

Python dependencies:
    sklearn: 1.5.0
    pip: 22.2.2
    setuptools: 65.3.0
    numpy: 1.26.4
    scipy: 1.13.1
    Cython: None
    pandas: 2.2.2
    matplotlib: 3.9.0
    joblib: 1.4.2
    threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    num_threads: 20
    prefix: libopenblas
    filepath: ~/my_project/.venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
    version: 0.3.23.dev
    threading_layer: pthreads
    architecture: Haswell

    user_api: blas
    internal_api: openblas
    num_threads: 20
    prefix: libopenblas
    filepath: ~/my_project/.venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-01191904.3.27.so
    version: 0.3.27
    threading_layer: pthreads
    architecture: Haswell

    user_api: openmp
    internal_api: openmp
    num_threads: 20
    prefix: libgomp
    filepath: ~/my_project/.venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    version: None