Unexpected behavior when passing multiple parameter sets to RandomizedSearchCV #18057


Open
justmarkham opened this issue Aug 1, 2020 · 7 comments
Labels
Bug, help wanted, Moderate (anything that requires some knowledge of conventions and best practices), module:model_selection


@justmarkham (Contributor)

Describe the bug

Here is part of the documentation for the param_distributions parameter of RandomizedSearchCV:

If a list of dicts is given, first a dict is sampled uniformly, and then a parameter is sampled using that dict as above.

My interpretation is that if I pass a list of two dictionaries, then at each iteration:

  • First, one of the two dictionaries will be chosen at random
  • Then, a set of parameters within that dictionary will be chosen at random

I have found that this is not the case: if one of the two dictionaries has many more possible parameter combinations, the larger dictionary is chosen far more often.

Steps/Code to Reproduce

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier()

# 30 possible combinations
params1 = {}
params1['n_estimators'] = [10, 20, 30, 40, 50]
params1['min_samples_leaf'] = [1, 2, 3, 4, 5, 6]

# 120 possible combinations
params2 = {}
params2['n_estimators'] = [60, 70, 80]
params2['min_samples_leaf'] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
params2['max_features'] = ['auto', None]
params2['bootstrap'] = [True, False]

params_both = [params1, params2]

rand = RandomizedSearchCV(rf, params_both, cv=5, scoring='accuracy', n_iter=50, random_state=1)
rand.fit(X, y)
print(sorted(rand.cv_results_['param_n_estimators']))

Expected Results

Since n_iter=50, I would expect that params1 and params2 would each be chosen about 25 times.
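
For reference, here is what that interpretation would look like if implemented naively. This is a purely hypothetical sketch written for this report (it is not how RandomizedSearchCV works internally): each dict is chosen with probability 1/2, and one value is then drawn for each of its parameters.

import random

# same dicts as in the reproduction above
params1 = {'n_estimators': [10, 20, 30, 40, 50],
           'min_samples_leaf': [1, 2, 3, 4, 5, 6]}
params2 = {'n_estimators': [60, 70, 80],
           'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
           'max_features': ['auto', None],
           'bootstrap': [True, False]}

rng = random.Random(1)

def sample_candidate(param_dicts):
    # first pick one dict uniformly, then pick one value per parameter
    d = rng.choice(param_dicts)
    return {name: rng.choice(values) for name, values in d.items()}

samples = [sample_candidate([params1, params2]) for _ in range(50)]
n_from_params1 = sum(s['n_estimators'] <= 50 for s in samples)
print(n_from_params1, 50 - n_from_params1)  # roughly 25 / 25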

Actual Results

Here is the actual output of the last line:

[10, 20, 30, 40, 40, 50, 50, 50, 50, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 70, 70, 70, 70, 70, 70, 70, 70, 70, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80]

As you can see, params1 (which had n_estimators values 10 through 50) was chosen only 9 times, whereas params2 (which had values 60 through 80) was chosen 41 times.

Comments

There seem to be two possibilities:

  1. The current behavior of RandomizedSearchCV is the desired behavior. In that case, I would propose tweaking the documentation to make this behavior more clear.
  2. The current behavior of RandomizedSearchCV is not the desired behavior. In that case, I would propose fixing the behavior so that it matches the documentation.

I don't have a strong feeling about which behavior is the "optimal" behavior.

This was implemented in #14549 by @amueller, so he may have some insight on this!

Versions

System:
    python: 3.8.2 | packaged by conda-forge | (default, Apr 24 2020, 07:56:27)  [Clang 9.0.1 ]
executable: /Users/kevin/miniconda3/envs/sk23/bin/python
   machine: macOS-10.14.6-x86_64-i386-64bit

Python dependencies:
          pip: 20.1.1
   setuptools: 46.4.0.post20200518
      sklearn: 0.23.1
        numpy: 1.18.4
        scipy: 1.4.1
       Cython: None
       pandas: 1.0.3
   matplotlib: 3.2.1
       joblib: 0.15.1
threadpoolctl: 2.0.0

Built with OpenMP: True
@jnothman (Member) commented Aug 1, 2020

I can confirm that the effect seems consistent across random_state. Strange.

@thomasjpfan (Member)

When the grid is made of all lists it goes down a different path, which flattens the grid:

all_lists = all(
    all(not hasattr(v, "rvs") for v in dist.values())
    for dist in self.param_distributions)
rng = check_random_state(self.random_state)
if all_lists:
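
In other words, on this path the candidates are effectively drawn from the union of all parameter combinations, so each dict is hit roughly in proportion to its size (30 vs 120 here, i.e. about a 1:4 ratio). This can also be seen from the public API; here is a small sketch using ParameterGrid and ParameterSampler directly (exact counts depend on random_state):

from sklearn.model_selection import ParameterGrid, ParameterSampler

# same two dicts as in the report above
params1 = {'n_estimators': [10, 20, 30, 40, 50],
           'min_samples_leaf': [1, 2, 3, 4, 5, 6]}          # 30 combinations
params2 = {'n_estimators': [60, 70, 80],
           'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
           'max_features': ['auto', None],
           'bootstrap': [True, False]}                      # 120 combinations

print(len(ParameterGrid([params1, params2])))               # 150 combinations in total

sampled = list(ParameterSampler([params1, params2], n_iter=50, random_state=1))
n_from_params1 = sum(p['n_estimators'] <= 50 for p in sampled)
print(n_from_params1, 50 - n_from_params1)                   # roughly 30/150 vs 120/150 of the 50 draws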

When replacing one of the ranges with a scipy distribution it works as expected:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
import scipy.stats

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier()

# 30 possible combinations
params1 = {}
params1['n_estimators'] = [10, 20, 30, 40, 50]
params1['min_samples_leaf'] = [1, 2, 3, 4, 5, 6]

# 120 possible combinations
params2 = {}
params2['n_estimators'] = [60, 70, 80]
params2['min_samples_leaf'] = scipy.stats.uniform(loc=0.25, scale=0.2)
params2['max_features'] = ['auto', None]
params2['bootstrap'] = [True, False]

params_both = [params1, params2]

rand = RandomizedSearchCV(rf, params_both, cv=2, scoring='accuracy', n_iter=50, random_state=1)
rand.fit(X, y)
print(sorted(rand.cv_results_['param_n_estimators']))

and outputs:

# [10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 40, 50, 50, 50, 50, 50, 50, 50, 60, 60, 60, 60, 60, 60, 60, 60, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 80, 80, 80, 80, 80]

Fixing this means we are changing behavior. Should we consider this a bug?

@NicolasHug (Member)

Should we consider this a bug?

I think so. The easiest fix is probably something like

all_lists = len(self.param_distributions) == 1 and all(
    not hasattr(v, "rvs") for v in self.param_distributions[0].values())

But I'm also wondering why we have this weird sampling mechanism:

If all parameters are presented as a list,
sampling without replacement is performed. If at least one parameter
is given as a distribution, sampling with replacement is used.

(Note that this was written long before we introduced support for lists of dicts.)

This behaviour doesn't seem natural and I find it quite unexpected: why should the presence of a scipy distribution influence how the entire dict is sampled? Sampling without replacement feels like we're trying to accommodate cases where GridSearch should be used instead.

Maybe we should deprecate this instead. But I'm not sure how we could do this smoothly?
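
To make the quoted behaviour concrete, here is a small illustration against ParameterSampler (just a sketch; the exact output depends on the scikit-learn version and random_state). With all-list values the candidates are drawn without replacement, so asking for as many iterations as there are combinations enumerates the grid exactly once; once a scipy distribution is involved, parameters are drawn independently and duplicates become possible:

from scipy import stats
from sklearn.model_selection import ParameterSampler

lists_only = {'a': [1, 2, 3], 'b': ['x', 'y']}
print(list(ParameterSampler(lists_only, n_iter=6, random_state=0)))
# all 6 (a, b) combinations, each exactly once (sampling without replacement)

with_dist = {'a': stats.randint(1, 4), 'b': ['x', 'y']}
print(list(ParameterSampler(with_dist, n_iter=6, random_state=0)))
# 6 independent draws; repeated (a, b) pairs are possible (sampling with replacement)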

@jnothman (Member) commented Aug 4, 2020 via email

@NicolasHug (Member)

Sampling without replacement was introduced because the distribution has a
finite number of distinct samples.

I find it a bit unexpected, especially considering that you can still have a finite number of candidates with discrete distributions anyway. Also, this behavior causes a bug in the __len__ magic method (fix for this in #18222).

@amueller (Member) commented Aug 21, 2020

So the more natural behavior would be to divide n_iter by the number of grids, and then sample without replacement from each flattened grid?

If one of the grids is smaller than the actually allocated size, that'll create weird edge-cases, right?
So if we have 100 iterations and grids of size 10, 50 and 100, the "expected" result is to have 10 from the first, 45 from the second and 45 from the last?

I don't see how this could be computed directly, but I guess the naive algorithm to compute the allocations is O(n_dicts ** 2), which isn't so bad. Though you can sort them by size and then do a linear scan, so it's O(n_dicts log(n_dicts)).
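
One possible reading of that allocation scheme, sketched here purely for illustration (not scikit-learn code): sort the grids by size, then give each remaining grid an equal share of the remaining budget, capped at its own size.

def allocate(n_iter, grid_sizes):
    # process grids from smallest to largest so that capped grids
    # return their unused budget to the larger ones
    order = sorted(range(len(grid_sizes)), key=lambda i: grid_sizes[i])
    allocations = [0] * len(grid_sizes)
    remaining_budget = n_iter
    for pos, i in enumerate(order):
        remaining_grids = len(grid_sizes) - pos
        share = remaining_budget // remaining_grids
        allocations[i] = min(grid_sizes[i], share)
        remaining_budget -= allocations[i]
    return allocations

print(allocate(100, [10, 50, 100]))  # [10, 45, 45], matching the example above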

@cmarmo added the Bug, help wanted, Moderate, and module:model_selection labels and removed the Bug: triage label on Oct 19, 2020
@hemant3434

take
