GridSearchCV fails if search space contains parameters of a complex type · Issue #29137 · scikit-learn/scikit-learn
@dmitry-lesnik

Description

Describe the bug

GridSearchCV fails at the last step, in _format_results(), if a parameter value is of a complex type, such as a dict whose values mix a string and a number.
The reproduction code below works for

param_grid=dict(special_param=[{"key1": 1.5, "key2": 18}, None])

but fails for

param_grid=dict(special_param=[{"key1": "some_string", "key2": 18}, None])

Steps/Code to Reproduce

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV


class MyLogReg(LogisticRegression):
    def __init__(self, special_param=None, C=1):
        super().__init__(C=C)
        self.special_param = special_param


X, y = make_classification(n_samples=100, random_state=100)

classifier = MyLogReg(C=1)
gs = GridSearchCV(estimator=classifier, cv=3, scoring="f1",
                  param_grid=dict(special_param=[{"key1": "some_string", "key2": 18}, None]),
                  verbose=2)
gs.fit(X, y)
print(gs.cv_results_["params"])

Expected Results

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
[{'special_param': {'key1': 'some_string', 'key2': 18}}, {'special_param': None}]

Actual Results

Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END ..special_param={'key1': 'some_string', 'key2': 18}; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
[CV] END .................................special_param=None; total time=   0.0s
Traceback (most recent call last):
  File "~/.config/JetBrains/PyCharm2024.1/scratches/scratch_1.py", line 18, in <module>
    gs.fit(X, y)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 968, in fit
    self._run_search(evaluate_candidates)
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 1544, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 962, in evaluate_candidates
    results = self._format_results(
  File "~/my_project/.venv/lib/python3.10/site-packages/sklearn/model_selection/_search.py", line 1092, in _format_results
    arr_dtype = np.result_type(*param_list)
  File "~/my_project/.venv/lib/python3.10/site-packages/numpy/core/_internal.py", line 62, in _usefields
    names, formats, offsets, titles = _makenames_list(adict, align)
  File "~/my_project/.venv/lib/python3.10/site-packages/numpy/core/_internal.py", line 32, in _makenames_list
    raise ValueError("entry not a 2- or 3- tuple")
ValueError: entry not a 2- or 3- tuple
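
The traceback shows np.result_type raising ValueError when it is handed the dict-valued parameter. As a rough illustration of how that call could be guarded, here is a minimal sketch. The helper name safe_result_type is hypothetical (it is not part of scikit-learn or NumPy), and this is not necessarily how the library should handle the case.

```python
import numpy as np


def safe_result_type(param_list):
    """Hypothetical helper: infer a common dtype, falling back to object."""
    try:
        return np.result_type(*param_list)
    except (TypeError, ValueError):
        # NumPy could not interpret one of the values as a dtype
        # (for example the dict from the failing grid), so fall back
        # to a generic object dtype instead of propagating the error.
        return np.dtype(object)


# Same parameter values as in the failing param_grid above.
complex_params = [{"key1": "some_string", "key2": 18}, None]
print(safe_result_type(complex_params))
```

Catching both exception types means the dtype inference never aborts the search; the cost is that such parameters would be stored with object dtype.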

Versions

import sklearn; sklearn.show_versions()
System:
    python: 3.10.4 (main, Sep 29 2022, 14:13:19) [GCC 11.2.0]
executable: ~/my_project/.venv/bin/python
   machine: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
Python dependencies:
      sklearn: 1.5.0
          pip: 22.2.2
   setuptools: 65.3.0
        numpy: 1.26.4
        scipy: 1.13.1
       Cython: None
       pandas: 2.2.2
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libopenblas
       filepath: ~/my_project/.venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Haswell
       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libopenblas
       filepath: ~/my_project/.venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-01191904.3.27.so
        version: 0.3.27
threading_layer: pthreads
   architecture: Haswell
       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libgomp
       filepath: ~/my_project/.venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
