Cloned estimators have identical randomness but different RNG instances · Issue #26148 · scikit-learn/scikit-learn

Cloned estimators have identical randomness but different RNG instances #26148


Open

avm19 opened this issue on Apr 11, 2023 · 25 comments

avm19 (Contributor) commented on Apr 11, 2023

Describe the bug

Cloned estimators have identical randomness but different RNG instances. According to the documentation, it should be the other way around: different randomness but one shared RNG instance.

Related #25395

The User Guide says:

For an optimal robustness of cross-validation (CV) results, pass RandomState instances when creating estimators

rf_inst = RandomForestClassifier(random_state=np.random.RandomState(0))
cross_val_score(rf_inst, X, y)

...
Since rf_inst was passed a RandomState instance, each call to fit starts from a different RNG. As a result, the random subset of features will be different for each folds

In regards to cloning, the same reference says:

rng = np.random.RandomState(0)
a = RandomForestClassifier(random_state=rng)
b = clone(a)

Moreover, a and b will influence each-other since they share the same internal RNG: calling a.fit will consume b’s RNG, and calling b.fit will consume a’s RNG, since they are the same.

The actual behaviour does not follow this description.

Steps/Code to Reproduce

import numpy as np
from sklearn import clone
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(0)
X, y = make_classification(random_state=rng)
rf = RandomForestClassifier(random_state=rng)

d = cross_validate(rf, X, y, return_estimator=True, cv=2)
rngs = [e.random_state for e in d['estimator']]
# estimators corresponding to different CV runs have different but identical RNGs:
print(rngs[0] is rngs[1]) # False
print(all(rngs[0].randint(10, size=10) == rngs[1].randint(10, size=10))) # True

rf_clone = clone(rf)
rngs = [rf.random_state, rf_clone.random_state]
print(rngs[0] is rngs[1]) # False
print(all(rngs[0].randint(10, size=10) == rngs[1].randint(10, size=10))) # True

Expected Results

True False True False

Actual Results

False True False True

Versions

Tested on a two-week-old dev build and also on the following version (Kaggle)

System:
    python: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:53)  [GCC 9.4.0]
executable: /opt/conda/bin/python3.7
   machine: Linux-5.15.90+-x86_64-with-debian-bullseye-sid

Python dependencies:
          pip: 22.3.1
   setuptools: 59.8.0
      sklearn: 1.0.2
        numpy: 1.21.6
        scipy: 1.7.3
       Cython: 0.29.34
       pandas: 1.3.5
   matplotlib: 3.5.3
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True
avm19 added the Bug and Needs Triage labels on Apr 11, 2023
avm19 (Contributor, Author) commented on Apr 11, 2023

As far as I could see, when an estimator is cloned, the random_state attribute gets deep-copied. In base.py:clone, clone() is recursively called on random_state with safe=False (Line 102), which causes random_state to be deep-copied (Line 83). As a result, a new RNG instance is created whenever an estimator is cloned.

There are several components to the issue.

  1. When doing cross-validation, we want estimators to have different randomness. We want this to ensure that CV is not affected too badly by lucky or unlucky initialisations of a stochastic estimator.
  2. When cloning an estimator, we want the documentation and the actual behaviour to match.
  3. We want reproducibility.

As for (2), the documentation says that clones should refer to one and the same RNG instance. However, if we implement this, don't we risk opening Pandora's box of concurrency and race-condition issues, thereby killing any hope of reproducibility? I think it is the documentation that should be changed for (2), not the code.

As for (1), cross-validation routines seem to follow the "parallel-delayed-clone" pattern:

results = parallel(
    delayed(_fit_and_score)(
        clone(estimator),

Thus, one solution for (1) would be to ensure that cloned estimators have different randomness. This randomness should be controllable.

@betatim mentioned what appears to be a nice solution. I quote him:

Something related is https://numpy.org/devdocs/reference/random/parallel.html#seedsequence-spawning (in particular the end of the linked to section) which talks about how to spawn new generators from an existing one. It also links to more material about why seed + fold_idx to generate "reproducible seeds" and such schemes are error prone

I did not have a chance to look deeper into all this, but the solution could be a special treatment of the random_state parameter in base.clone. If I am not mistaken, the generator expression in the "parallel-delayed-clone" pattern is consumed sequentially in the main process, so the clone() calls happen in a deterministic order, and something like this should work:

    for name, param in new_object_params.items():
        if name == "random_state" and type(param) is RandomState:
            new_object_params[name] = param.spawn()
        else:
            new_object_params[name] = clone(param, safe=False)
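
As a side note, np.random.RandomState itself does not provide a spawn() method (spawning lives on SeedSequence, BitGenerator and the newer Generator API), so the snippet above is best read as pseudocode. A minimal sketch of the same spawning idea using NumPy's SeedSequence (an illustration of the concept, not scikit-learn's actual clone implementation):

import numpy as np

# One parent SeedSequence deterministically spawns independent child sequences,
# one per clone, each driving its own generator.
parent = np.random.SeedSequence(0)
child_a, child_b = parent.spawn(2)
rng_a = np.random.default_rng(child_a)
rng_b = np.random.default_rng(child_b)

# The two streams are reproducible across runs yet different from each other,
# which is the behaviour we want for CV clones.
print(rng_a.integers(10, size=5))
print(rng_b.integers(10, size=5))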

P.S. To sum up, the proposed solution is to ensure different randomness and different RNG instances.

adrinjalali added the Documentation label and removed the Needs Triage and Bug labels on Apr 13, 2023
adrinjalali (Member) commented

So @avm19 is right that when cloning an estimator, a new RNG instance is created via deepcopy, and therefore the documentation is wrong and needs to be fixed.

Re: random state in scikit-learn, @NicolasHug has done a lot of work, which unfortunately didn't end up anywhere since we didn't manage to reach a consensus (xref SLEP proposal: https://github.com/scikit-learn/enhancement_proposals/pull/24/files). random_state as it is in scikit-learn has many issues, which are explained in that SLEP draft.

For now, a PR to fix the documentation to match the actual code is welcome. Also, note that if you want the estimators to have different random states, simply don't pass a random state object to them, or pass numpy.random.RandomState(seed=None).

ogrisel (Member) commented on Apr 13, 2023

So @avm19 is right that when cloning an estimator, a new RNG instance is created via deepcopy, and therefore the documentation is wrong and needs to be fixed

Are you sure? I don't think any deepcopy is called in this case. It's just that we keep the same reference to the same RandomState instance passed as a constructor param.

A new RandomState instance should be created when the random_state passed to the constructor is an int instead.

I think we just need to improve our doc.

adrinjalali (Member) commented

Are you sure?

@ogrisel I think yes; the following code prints False twice:

# %%
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_features=5, random_state=42)

# %%
rng = np.random.RandomState(10)
rng2 = clone(rng, safe=False)
print(id(rng) == id(rng2))
# %%
# According to the docs (https://scikit-learn.org/stable/common_pitfalls.html#id2),
# in particular the subsection on cloning estimators, these two estimators
# should influence each other: they share the `rng` instance.
sgd = SGDClassifier(random_state=rng)
sgd2 = clone(sgd)
print(id(sgd.random_state) == id(sgd2.random_state))

A new RandomState instance should be created when the random_state to the constructor is passed as an int instead.

If the user passes an int, the instance is indeed created anew, but with the same seed, and therefore it generates the same sequence of numbers, which defeats the purpose of what's described in the docs here.
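
For concreteness, a small sketch of what happens with an int seed, using sklearn.utils.check_random_state (the helper estimators use to normalise random_state): each consumer builds a fresh RandomState, but all of them produce the same stream.

from sklearn.utils import check_random_state

r1 = check_random_state(42)
r2 = check_random_state(42)
print(r1 is r2)                                               # False: distinct instances
print(all(r1.randint(10, size=5) == r2.randint(10, size=5)))  # True: identical sequences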

betatim (Member) commented on Apr 13, 2023

I think we just need to improve our doc.

I agree with what Olivier said in terms of what should happen (shared reference). This is also what the docs say should happen. However, I don't think this is what actually happens. Otherwise code like https://gist.github.com/betatim/66b6ee6a780ec2e1c54f653eff198d9d would work (in the sense of giving different coef_).

As a result I don't think the docs need improving, instead we need to improve the code.

betatim (Member) commented on Apr 13, 2023

And because I've forgotten twice already: thanks for digging into this, reading the docs and reporting this @avm19. This is a topic with many subtleties where we need to get what should happen, what does happen and what the docs say in alignment.

adrinjalali (Member) commented

As a result I don't think the docs need improving, instead we need to improve the code.

For that we'd need to revive the SLEP.

NicolasHug (Member) commented

I am extremely confused.

We've had plenty of discussions with @GaelVaroquaux, @amueller, @ogrisel, @agramfort and @adrinjalali [1] about this behaviour, and we all had the same shared (wrong???) understanding that this snippet passes - i.e. that the RandomState instance is shared across clones and that the docs are correct:

import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_random_state

rng = np.random.RandomState(0)

rf = RandomForestClassifier(random_state=rng)
rf2 = clone(rf)

assert check_random_state(rf.random_state) is check_random_state(rf2.random_state)

So we either got it all wrong all this time, or something changed in the way we clone(), or in the way np.random.RandomState reacts to being deep-copied?

The current state (wrong docs, RandomState instances not shared across clones) is quite alarming because it means that the only way to get maximal entropy for scikit-learn estimators is to leave random_state=None, which throws reproducibility out of the window. CC @GaelVaroquaux because I know you care about this.

I'm still in the denial phase right now - it has to be a change in clone() or in RandomState right?? I still remember conversations with core devs claiming that yeah, shared RandomState instances across clones is what we have and want. I started writing those docs precisely because I found those randomness subtleties so confusing and error-prone, and they were never documented before. And I still managed to get it completely wrong, even after 10+ people reviewed the PR???

As a result I don't think the docs need improving, instead we need to improve the code.

For that we'd need to revive the SLEP.

Whether we consider this a bugfix or not, implementing the behaviour claimed by the docs is going to be a massive behaviour change - and a silent one. As for SLEP 11, I never submitted it but I wrote a whole new version in https://github.com/NicolasHug/enhancement_proposals/blob/random_state_v2/slep011/proposal.rst. It (wrongly?) assumes RandomState instances are shared across clones but there might be relevant ideas in there as to how we can make all this more explicit and less error-prone.

[1] in scikit-learn/enhancement_proposals#24 in #18363, and offline

adrinjalali added the Needs Decision label on Apr 13, 2023
betatim (Member) commented on Apr 13, 2023

I think it is worth slowing down, because it is a hairy issue and the impact is big. The good thing is that it's probably been this way for a while, so spending an extra moment to disentangle things is not a big cost.

I also thought that what the docs describe is how it should be as well as how it is.

However, Nicolas' code fails for me on version 1.0.2 (same with the SGD based example). This is doubly weird because the docstring of clone explicitly talks about how random_state is handled in a special way.

adrinjalali (Member) commented

That's very interesting. I also tried with an old (2019) set of versions, and the code fails:

>>> import numpy as np
>>> from sklearn.base import clone
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.utils import check_random_state
>>> rng = np.random.RandomState(0)
>>> rf = RandomForestClassifier(random_state=rng)
>>> rf2 = clone(rf)
>>> assert check_random_state(rf.random_state) is check_random_state(rf2.random_state)
 assert check_random_state(rf.random_state) is check_random_state(rf2.random_state)

AssertionError: 

>>> sklearn.show_versions()
/home/adrin/miniforge3/envs/delete-me/lib/python3.7/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")

System:
    python: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21)  [GCC 9.4.0]
executable: /home/adrin/miniforge3/envs/delete-me/bin/python
   machine: Linux-6.2.10-arch1-1-x86_64-with-arch

BLAS:
    macros: HAVE_CBLAS=None
  lib_dirs: /home/adrin/miniforge3/envs/delete-me/lib
cblas_libs: blas, cblas, lapack, blas, cblas, lapack, blas, cblas, lapack

Python deps:
       pip: 23.0.1
setuptools: 67.6.1
   sklearn: 0.20.4
     numpy: 1.15.4
     scipy: 1.5.1
    Cython: None
    pandas: None

agramfort (Member) commented on Apr 13, 2023 via email

NicolasHug (Member) commented on Apr 14, 2023

@agramfort so your expectations match what the code does.

It means that all the RandomForest clones will have the same RNG across all folds of a CV procedure / grid search. The only way to get maximal entropy, i.e. different clones across folds, is to leave random_state=None (=> no reproducibility). Passing a RandomState instance or an int currently makes no difference for CV procedures.
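
To make that concrete, a small sketch building on the reproduction from the issue description (check_random_state normalises whatever was stored on each fitted fold estimator):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate
from sklearn.utils import check_random_state

X, y = make_classification(random_state=0)
for random_state in (0, np.random.RandomState(0)):
    rf = RandomForestClassifier(random_state=random_state)
    d = cross_validate(rf, X, y, cv=2, return_estimator=True)
    rngs = [check_random_state(e.random_state) for e in d["estimator"]]
    # With either an int or a RandomState instance, the per-fold clones end up
    # drawing identical random streams:
    print(all(rngs[0].randint(10, size=10) == rngs[1].randint(10, size=10)))  # True, True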

GaelVaroquaux (Member) commented on Apr 14, 2023 via email

betatim (Member) commented on Apr 18, 2023

I've added this to the agenda for the next monthly meeting (24th April). I think having a short sync discussion would be helpful, even if it won't answer all questions "once and for all".

thomasjpfan (Member) commented

During the Drafting meeting, @adrinjalali, @ogrisel, @glemaitre, @betatim and I discussed many aspects of random state.

Here is my summary of the discussion:

  1. We agree on random_state=None and random_state=int behavior. The hard part is the random state objects.
  2. We agree that we should leave the current behavior alone. I think the docs should be updated to reflect the actual behavior.
  3. We want to use the move to NumPy Generators to support the desired behavior. We do not need to worry about backward compatibility, because NumPy Generators are new objects to pass around.
  4. We reconsidered SLEP011 and are undecided.
  5. There is no API contract about randomness for multiple calls to predict. In scikit-learn, I think we are consistent and that multiple calls to predict with the same data give the same output.
  6. We have a conflict about what the desired behavior of the following should be:
rng = np.random.RandomState(42)
rf1 = RandomForestClassifier(random_state=rng).fit(X, y)
rf2 = RandomForestClassifier(random_state=rng).fit(X, y)

# Should the above give the same model, and thus the same predictions?
np.testing.assert_array_equal(rf1.predict(X), rf2.predict(X))

Cross validation and max entropy

We agree that any cross-validation procedure such as GridSearchCV should give the maximum amount of entropy, with a way to make it deterministic. As described in #26148 (comment), there is currently no way to achieve that and remain deterministic. We did not come up with a good way to do it with Generators either.

Estimators that know what they are spawning (RandomForestClassifier) do not have this issue and can use SeedSequences. Concretely, RandomForestClassifier can pass the random state objects to the trees directly because the forest knows it is dealing with trees. On the other hand, GridSearchCV cannot pass this random state down because there is no API for passing these objects to an estimator.
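
A toy sketch of that idea (hypothetical code, not the actual RandomForestClassifier implementation): a meta-estimator that knows its children can hand each of them an independent, reproducible stream via SeedSequence.

import numpy as np

class ToyForest:
    """Hypothetical meta-estimator that spawns one independent stream per sub-model."""

    def __init__(self, n_trees=3, seed=0):
        self.n_trees = n_trees
        self.seed = seed

    def fit(self, X, y):
        # One child SeedSequence per tree: reproducible given `seed`, yet the
        # trees receive statistically independent streams.
        children = np.random.SeedSequence(self.seed).spawn(self.n_trees)
        self.tree_rngs_ = [np.random.default_rng(child) for child in children]
        return self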

GaelVaroquaux (Member) commented on May 6, 2023 via email

adrinjalali (Member) commented

The point we made, which is lost in the summary, is that people should be able to have randomized predict, but they should also be able to control that randomness if needed.

lorentzenchr (Member) commented on May 11, 2023

My take on point 6 in #26148 (comment) is that it should result in two different random forests, so the assert should fail. Otherwise we make it hard for someone who knows what she is doing with the rng.

Can't we make GridSearchCV aware that the (meta-)estimator (e.g. a pipeline) supports a random_state argument? SLEP006 maybe?
If grid search knows that, it can clone and then set the random state of the cloned object.

betatim (Member) commented on May 11, 2023

I wonder about point 6: people might use a predictor that samples the posterior, and be surprised if multiple calls give the same samples.

I think you can argue that this is a case that is different from what the code snippet shows. At least for me what you point out is more akin to this:

clf = FooWithSamplingClassifier(random_state=<int or random_state>).fit(X, y)

# calling once
pred1 = clf.predict(X)
# calling twice
pred2 = clf.predict(X)
assert not np.array_equal(pred1, pred2)

Where FooWithSamplingClassifier is a classifier that does some sampling as part of its prediction.

For me the take-away is that predict() can't make promises, because the right answer depends on the estimator. Some estimators are fully deterministic (repeated calls to predict with the same X return the same prediction), some estimators use randomness to perform sampling (repeated calls to predict with the same X return different predictions), and some estimators use randomness to make predictions but should have no history (repeated calls to predict with the same X return the same prediction). The last case is my understanding of an example Adrin brought up: an algorithm that uses randomness to make its prediction, but whose prediction should not depend on how many times predict has been called previously.

These three use cases all seem legitimate, so I think we can't enforce a particular contract; it will depend on the estimator. It also means that whether an int or a RandomState is passed in doesn't matter: you need to read the estimator's docs to know how it will behave.

My take on point 6 in #26148 (comment) is that it should result in two different random forests, so the assert should fail. Otherwise we make it hard for someone who knows what she is doing with the rng.

I think a person's expectation of what happens will depend on whether they think RandomState(42) is a mutable object or not. If you think it is mutable, then it makes sense that rf1 and rf2 are different. If you think it is immutable, then rf1 and rf2 should be the same. I don't know what my naive expectation, unbiased by history, about the (im)mutability of RandomState objects would be; I've learnt they are mutable. However, I also know that at some point every Python programmer falls into the "mutable object as default argument value" trap (def foo(arg=[]):), and it is tricky to explain to novice Python users why it totally makes sense that things work this way.

This makes me wonder if life would be simpler for everyone if we considered them immutable. Using the JAX notation (because you can go and read about it), to get two estimators that are the same you'd do:

from jax import random
key = random.PRNGKey(0)

rf1 = RandomForestClassifier(random_state=key)
rf2 = RandomForestClassifier(random_state=key)

To have two different estimators you'd split the key before passing it to the constructor:

from jax import random
key = random.PRNGKey(0)
key, rf1_key, rf2_key = random.split(key, 3)
rf1 = RandomForestClassifier(random_state=rf1_key)
rf2 = RandomForestClassifier(random_state=rf2_key)

I quite like this. It is only marginally more complex for a user to understand than using integers as seeds, and like with seeds you can get both kinds of behaviour. I think it is also possible to do this with the new random number generator infrastructure of NumPy?!
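
For reference, a rough NumPy analogue of the key-splitting pattern above (a sketch of the NumPy concept only, not of any scikit-learn API):

import numpy as np

# Reusing the same seed gives two generators with identical streams -- the
# analogue of passing the same key to both constructors above.
gen_same_1 = np.random.default_rng(0)
gen_same_2 = np.random.default_rng(0)

# Spawning (available on Generator since NumPy 1.25, and on SeedSequence before
# that) plays the role of random.split: the children are reproducible given the
# parent but independent of each other.
parent = np.random.default_rng(0)
gen_a, gen_b = parent.spawn(2)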

(n.b. I have rewritten my comment three or four times over the course of the last hour, and each time I changed my conclusion. But I think I like this version now, but who knows, if I think some more about it I might change my mind again??)

NicolasHug (Member) commented

I hope I'm not bringing additional confusion here but I think the last few comments may have been mixing up point 5 and point 6.

As far as I understand, point 5 is about "randomized predict", i.e. whether est.predict(X); est.predict(X) should give the same predictions, while point 6 relates to the (possibly consumed) RNG involved in fit().

betatim (Member) commented on May 12, 2023

I think you are right Nicolas :-/ Sorry.

Having thought about my previous comment regarding JAX keys and splitting, and having had a discussion with @seberg, I think I am on the cusp of changing my mind again compared to what I wrote :-/

Sebastian pointed out that either option (fit modifies the state passed in / fit leaves it unmodified) can be a foot-gun. If most people want the two RFs from my example to be independent, then it is annoying to make them split the key every time (and people will forget to do so, hence foot-gunning themselves). The same argument applies in reverse if most users want the estimators to see the same randomness.

We can make both options happen. One by teaching the minority that they need to call split (default is same randomness), the other by teaching the minority that they need to copy (default is not the same randomness). I think the important thing is that we don't end up having to teach the majority about an extra step.

NicolasHug (Member) commented on May 12, 2023

OK.

I've been reluctant to publish scikit-learn/enhancement_proposals#88 (rendered docs are here) so far because I won't be able to champion it and it has become slightly out of date. But I think the SLEP does a decent job of describing the key requirements we need to support for a solution, so I hope it can help frame the discussion better.

FWIW I think the solution advocated in the SLEP is quite similar to what you suggested above @betatim. The main idea is to 1) make estimators (and splitters) fully stateless across calls to fit(), and 2) expose the randomness complexity to users by letting them choose the kind of CV procedure they want to run (with good defaults). Right now this complexity is both hidden and largely uncontrollable.

For completeness, a toy implementation of the solution is available in this notebook.
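
A separate toy illustration of the "stateless across calls to fit()" idea (a sketch under that assumption, not code from the SLEP or the linked notebook): derive a fresh generator from a stored seed on every call instead of mutating a shared RNG.

import numpy as np

class StatelessToyEstimator:
    """Hypothetical estimator whose fit() never depends on prior RNG consumption."""

    def __init__(self, seed=0):
        self.seed = seed

    def fit(self, X, y):
        # A fresh Generator is derived from the stored seed on every call, so
        # fitting twice on the same data always yields the same model.
        rng = np.random.default_rng(self.seed)
        self.coef_ = rng.normal(size=X.shape[1])
        return self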

seberg (Contributor) commented on May 12, 2023

Not coming from the sklearn boat, so I may well be completely off and am happy to be ignored, but on both points I am not sure I can fully agree. I would probably argue that users should never re-use identical pseudo-randomness unless they are explicit about it.
(For the case where a user calls .clone(), maybe that is explicit enough.)

Because of that, I will ask: what is the benefit if calling .fit(X, y) twice gives the same result?

  • If the user calls it with identical data, I think things are pretty simple:
    • They want identical results: They will just notice its not identical and change the code.
    • They don't want identical results: It would be a subtle bug to get identical results (one that I bet many will never notice).
  • If the user calls it with different data:
    • They will definitely not notice re-using the same randomness over-and-over. And it may be a big error depending on the context.
    • Some users might want to run everything e.g. with the same initial conditions. I am unsure if this is actually common. My gut feeling would be that these are more likely to be expert users than the other group. (However, this also might be very subtle to notice that you are not getting what you expected.)

My suspicion would be that most (especially inexperienced?) users will almost never need identical streams and will end up with a subtly bad RNG if that is what you give them. But maybe I am underestimating how often it is required? OTOH, maybe when you do need it there are often other ways (such as explicitly setting the initial conditions).

TBH, I am not sure it is important whether users lean towards "of course calling .fit(X, y) twice will lead to the same result". Did that intuition even take randomness into account? And even if it did, will their code subtly rely on it? (Or which path is more likely to cause subtle issues? I would argue the current random_state=integer is that: all you wanted was to make the plots at the end deterministic, but what you got is the same initial condition on all CV splits!)

lorentzenchr (Member) commented

@seberg The question is more about what the result of calling fit on the same data should be, either with a cloned estimator or with an identically parametrized estimator that was passed the same object as its rng parameter.

Meanwhile, I think discussing options for "point 6" in #26148 (comment) does not solve our issue(s). Both options are fine and can be used to implement the same functional behavior. It's then more a question of having it documented and tested.

IMHO, @NicolasHug initiated the right approach with his draft SLEP: write down requirements and search for solutions by extending estimators use case by use case. Ideally, we don't bother users at all with randomness (and thereby protect them from the danger of serious foot injuries) while still providing Swiss-army-knife flexibility for experts.

I hope, for instance, that we all agree on requirement 1: it should be very easy, or the default, to have exact reproducibility (when running the same Python code in different Python processes in the same order).

seberg (Contributor) commented on May 12, 2023

Right, I had more things, but it's complicated and of course mapping things out is important. I suppose I have a few questions (opinions):

  • I don't really like the idea of rng=integer and rng=np.random.default_rng(integer) behaving differently (they can result in different values, but should not show subtly different behavior on cloning).
  • I don't really like the idea that users might use the same random numbers twice unless they are explicit about it. With NumPy itself this is almost impossible.
  • However, there seems to be a thing around CV.split() where doing that may even be the clearly expected thing.
    • Question: Could there be a different way to spell this? CV.rewind(), or CV(rng=1234, repeat=True) or...?

The thing about CV and estimators being in different boats seems potentially super tricky, and yes, maybe cloning actually cloning the state too makes that work. But maybe it is worth considering whether there are orthogonal solutions?
