Support early_stopping with custom validation_set · Issue #18748 · scikit-learn/scikit-learn

deltawi opened this issue Nov 3, 2020 · 27 comments

@deltawi
deltawi commented Nov 3, 2020

Describe the workflow you want to enable

Today in SGDClassifier, the early_stopping parameter holds out a random fraction of the training data for validation. It would be useful to support a custom validation set chosen by the user.

Describe your proposed solution

for example:

clf = SGDClassifier(early_stopping=True)
clf.fit(X_train, y_train, eval_set=(X_val, y_val))

EDIT

Broader Scope

Same applies to GradientBoosting* and HistGradientBoosting*

@TomDLT
Member
TomDLT commented Nov 19, 2020

I agree this could be valuable. (see Nicolas' comment)

Currently, the validation split is implemented as a mask array validation_mask, which is used to define X_val = X[validation_mask] and X_train = X[~validation_mask]. We could thus avoid adding a new parameter by accepting a predefined mask array in the parameter validation_fraction (a rough sketch follows the checklist below).

We would need to:

  • update _make_validation_split with an early return if self.validation_fraction is an array
  • validate the array: cast it to uint8 dtype, check that its length matches X's
  • update the documentation
  • add a test

Do you want to give it a try?
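
A rough, self-contained sketch of what that early return could look like (this is not the actual scikit-learn helper; the signature, the boolean dtype, and the standalone form are simplifications for illustration):

import numpy as np

def _make_validation_split(validation_fraction, y, random_state=None):
    """Simplified sketch: return a boolean validation mask of length len(y).

    If ``validation_fraction`` is already an array, treat it as a
    user-provided mask (the proposed early return); otherwise draw a
    random split of the requested size.
    """
    n_samples = y.shape[0]
    if isinstance(validation_fraction, np.ndarray):
        mask = validation_fraction.astype(bool)
        if mask.shape[0] != n_samples:
            raise ValueError("validation mask length does not match X")
        return mask
    rng = np.random.default_rng(random_state)
    mask = np.zeros(n_samples, dtype=bool)
    n_val = int(np.ceil(validation_fraction * n_samples))
    mask[rng.choice(n_samples, size=n_val, replace=False)] = True
    return mask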

@NicolasHug
Member

I think we've tried to avoid sample-aligned init parameters so far: passing a sample-aligned mask in __init__ would mean that the estimator wouldn't be usable in any cross-validation helper, like cross_val_score or GridSearchCV.

For ref there were related discussions in #15127 and #5915

@TomDLT
Member
TomDLT commented Nov 19, 2020

Makes sense, forgot about this consideration.

@adrinjalali
Member

The issue I have with passing validation_data to models' fit is that I'm not sure how we'll manage it when the model is in a pipeline. The validation_data needs to go through the same preprocessing steps of the pipeline as the rest of the data. The user thus cannot set aside a part of the data and pass it directly to the model at the end of the pipeline.

In many of the discussions I see here and elsewhere, people seem to not have that requirement?
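
To make that concern concrete: without pipeline support, the user has to redo the preprocessing by hand so the validation data matches what the final estimator sees. A minimal sketch, assuming the hypothetical eval_set fit parameter proposed in this issue (it does not exist today):

from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

# The preprocessing has to be fit on the training portion only...
scaler = StandardScaler().fit(X_train)
X_train_t = scaler.transform(X_train)
# ...and the user must remember to apply the same fitted transform to the
# validation data before handing it to the final estimator.
X_val_t = scaler.transform(X_val)

clf = SGDClassifier(early_stopping=True)
clf.fit(X_train_t, y_train, eval_set=(X_val_t, y_val))  # hypothetical kwarg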

@TomDLT
Member
TomDLT commented Nov 21, 2020

Scikit-learn's API is centered around easy cross-validation, but I guess a lot of people use their own cross-validation tools, for better or worse.

@docmarionum1

+1, would find this very useful across the different models that support early stopping validation.

TomDLT's idea of using indices seems like it could address adrinjalali's concern about preprocessing, since the validation data would only be split off from the training data once it reaches the model. That, or supporting GroupShuffleSplit if the user provides group labels to fit.

I have data where points from each group are highly correlated. I'm using group splitters to validate outside of training, but since the training data is randomly split for validation (in my case by HistGradientBoostingClassifier), early stopping doesn't work as intended.

Of course, I understand that this has a lot of consequences beyond just these models, since sklearn has a consistent API for fit, but it would be very useful :)
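
For illustration, a group-aware variant could look roughly like this; validation_split is a made-up parameter name and X, y, groups are placeholders, so nothing below exists in scikit-learn today:

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical: let the estimator draw its internal validation set with a
# group-aware splitter instead of a plain random fraction.
clf = HistGradientBoostingClassifier(
    early_stopping=True,
    validation_split=GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=0),
)
clf.fit(X, y, groups=groups)  # groups would have to be routed to the splitter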

@eddiewu01

+1, would find this feature very useful. In my work we have hand-crafted train and validation sets, and ideally we want to stop training once performance on the validation set gets worse for some number of iterations.

@adrinjalali
Member

I guess we could think of an API where Pipeline's fit, as well as many other estimators and all other meta-estimators, accepts a validation_set and handles it properly. But that's quite a large project to pull off.

@eddiewu01

Could we start a PR and work on this collectively? I still need to understand how sklearn API works in general but would love to learn more and contribute!

@thomasjpfan
Member

Could we start a PR and work on this collectively? I still need to understand how sklearn API works in general but would love to learn more and contribute!

Given the complexity of the issue, I would start with proposing an API here first, after learning about scikit-learn's API. Resolving this issue involves adding new API to estimators that have early stopping, and requires thinking about how meta-estimators like Pipeline interact with this new API.

Side Note: skorch.net.NeuralNet has a train_split __init__ parameter, which is a callable that splits the data into train & validation sets. For predefined splits, there is a predefined_split helper, which allows for custom splits. For this workflow to integrate nicely with pipelines, the Pipeline would also need a train_split parameter, do the correct preprocessing on the validation set, and pass that information into the train_split __init__ parameter of the last step.
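
Roughly, the skorch pattern looks like this (MyModule and the arrays are placeholders, and the details are written from memory of skorch's helpers, so treat them as approximate):

from skorch import NeuralNetClassifier
from skorch.dataset import Dataset
from skorch.helper import predefined_split

# Hand the net a fixed validation set instead of letting it split internally.
valid_ds = Dataset(X_val, y_val)

net = NeuralNetClassifier(
    MyModule,                               # some torch.nn.Module (placeholder)
    train_split=predefined_split(valid_ds),
)
net.fit(X_train, y_train)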

@adrinjalali
Member

Would we allow the same validation for transformers as well? Imagine an encoder at the first step of the pipeline, which would have its own early stopping criteria. Is the validation set used for that step the same as the validation set used for the last step of the pipeline?

@thomasjpfan
Member

Imagine an encoder at the first step of the pipeline, which would have its own early stopping criteria. Is the validation set used for that step the same as the validation set used for the last step of the pipeline?

If we want to be simple, then yes. The transformer would get the non-transformed version for validation and the final step would get the transformed version.

If we try to place the validation set into fit kwargs, then the issue can be considered a more complicated metadata routing problem under SLEP006. It can use the infrastructure of metadata routing to configure an estimator's request, but the actual routing is different compared to sample_weight. For example, Pipeline would need to actually transform the validation set for the final step.
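
A minimal sketch of the extra routing work this implies for Pipeline; everything below, including the eval_set fit parameter and the helper itself, is hypothetical:

def fit_with_validation(pipeline, X, y, X_val, y_val):
    """Hypothetical: what a routing-aware Pipeline.fit might have to do.

    Fit every step but the last on the training data, push the validation
    set through the same fitted transformers, then hand both to the final
    estimator (assuming it accepted an eval_set fit parameter).
    """
    Xt, Xt_val = X, X_val
    for _, step in pipeline.steps[:-1]:
        Xt = step.fit_transform(Xt, y)
        Xt_val = step.transform(Xt_val)
    final_estimator = pipeline.steps[-1][1]
    final_estimator.fit(Xt, y, eval_set=(Xt_val, y_val))  # hypothetical kwarg
    return pipeline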

@ogrisel
Member
ogrisel commented Oct 29, 2021

Note that auto-splitting from the training set inside the final classifier/regressor is problematic when this estimator is wrapped in a rebalancing meta-estimator to tackle target-imbalance problems: rebalancing should happen only on the training data, while early stopping, model selection and evaluation should only use metrics computed on data with the original class balance.

I am not sure an auto-magical API would work for this. Making it possible to pass a manually prepared validation set might be the sanest way to deal with this situation.
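
A concrete shape of that failure mode, using imbalanced-learn as the rebalancing wrapper (X and y stand for some imbalanced dataset; this is a sketch of the problem, not a measurement):

from imblearn.over_sampling import RandomOverSampler
from imblearn.pipeline import make_pipeline
from sklearn.ensemble import HistGradientBoostingClassifier

# The sampler rebalances the data during fit, so the classifier's internal
# validation_fraction split is drawn from the *rebalanced* data: early
# stopping is then judged on a class distribution that no longer matches
# the one we actually care about.
pipe = make_pipeline(
    RandomOverSampler(random_state=0),
    HistGradientBoostingClassifier(early_stopping=True, validation_fraction=0.1),
)
pipe.fit(X, y)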

@ogrisel
Member
ogrisel commented Oct 29, 2021

I made a similar point in #15127 (comment).

@lorentzenchr
Member

Does this issue benefit from SLEP006 metadata routing? If yes, maybe some example code would be enough?

@adrinjalali
Member
adrinjalali commented Jul 31, 2023

I think if we keep the validation set fixed, then yes.

@lorentzenchr
Member
lorentzenchr commented Aug 3, 2023

I think if we keep the validation set fixed, then yes.

That'll work. Is it then a parameter to fit (I think so) or init?
Do we expect a y_validation, X_validation, sample_weight_validation or indices or a splitter? And how to name that argument?

LightGBM has eval_set (a list of (X, y) tuples) and eval_sample_weight in fit.
XGBoost also has eval_set and sample_weight_eval_set.
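
For comparison, those fit signatures look roughly like this (X_train, X_val, w_val are placeholders, and the exact early-stopping configuration varies across versions of both libraries):

import lightgbm as lgb
import xgboost as xgb

# LightGBM: eval_set is a list of (X, y) pairs with a matching
# eval_sample_weight list; early stopping itself is configured separately
# (via callbacks in recent versions).
lgb_clf = lgb.LGBMClassifier()
lgb_clf.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    eval_sample_weight=[w_val],
)

# XGBoost: same idea with eval_set and sample_weight_eval_set.
xgb_clf = xgb.XGBClassifier(early_stopping_rounds=10)
xgb_clf.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    sample_weight_eval_set=[w_val],
)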

@adrinjalali
Member

So with metadata routing this would work:

from sklearn.datasets import load_iris
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline

X, y = load_iris(return_X_y=True)
X, X_eval, y, y_eval = train_test_split(X, y)

preprocessing = Pipeline(...)  # preprocessing steps
X = preprocessing.fit_transform(X, y)
X_eval = preprocessing.transform(X_eval)

gs = GridSearchCV(
    # X_eval / y_eval as fit parameters are the proposed API, requested via
    # metadata routing.
    HistGradientBoostingClassifier().set_fit_request(X_eval=True, y_eval=True),
    ...,
)

gs.fit(X, y, X_eval=X_eval, y_eval=y_eval)

But you can't put the preprocessing in a pipeline and pass that to the grid search, since in the grid search the preprocessing steps are re-fit with new parameters if we're tuning over them, and then X_eval is not transformed the way it should be.

I think that's already a good step, but we'd also need to modify Pipeline to handle data that is supposed to be transformed before being fed to the next steps.

Does that make sense?

@lorentzenchr
Member

@adrinjalali Thanks. Yes, that makes sense. If nobody else is working on it or intends to do so, I'll soon open a PR for HGBT.

@lorentzenchr
Member

The consensus of a longer discussion at the drafting meeting on 19.01.2024 was to go with passing splitter objects (option 2) as constructor parameters to the estimators (e.g. HistGradientBoosting*).
To recap, the discussed options were (a rough sketch of option 2 follows the tables):

Number | Option                                               | Pro        | Con
1      | CV-estimators like LogisticRegressionCV              |            | ugly & maintenance
2      | Pass a splitter as constructor argument              | works      | some data leakage
3      | Pass a pre-computed validation dataset X_val, y_val  | works      | X_val can't be processed/created in a pipeline
4      | Use a callback                                       | might work | not yet accepted/implemented

Problem                                                | 1    | 2                         | 3       | 4
split on a column that should not be used for fitting  | no   | yes, via metadata routing | yes     | ?
X_val requires same preprocessing as "X_train"         | yes  | yes                       | no      | ?
data leakage                                           | some | some                      | depends | ?
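
For illustration, option 2 might look something like this; validation_split is a made-up parameter name and nothing here is an agreed API:

from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import ShuffleSplit

# Option 2: the estimator still draws its own validation set, but with a
# user-supplied splitter instead of a hard-coded random fraction.
est = HistGradientBoostingRegressor(
    early_stopping=True,
    validation_split=ShuffleSplit(n_splits=1, test_size=0.15, random_state=0),
)
est.fit(X, y)  # X, y: placeholders for the full training data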

@jeremiedbb
Member
jeremiedbb commented Feb 29, 2024

To be more precise regarding the above chart: callbacks are not an alternative way to provide the validation set. In order to have interesting callbacks (early stopping, monitoring, ...), the validation set must be accessible when the callback is called. However, the solution we chose for a unified API to provide a validation set is somewhat orthogonal to the callback API.

Then, once this question is solved, an early-stopping callback will imo be a good way to get a consistent API for early stopping across all estimators, instead of each estimator implementing its own version of early stopping.
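
To sketch the kind of consistency a callback could buy (an entirely hypothetical API; no such callback exists in scikit-learn today):

# Hypothetical callback: one early-stopping implementation shared by every
# iterative estimator, instead of per-estimator reimplementations.
class EarlyStopping:
    def __init__(self, patience=10, tol=1e-4):
        self.patience = patience
        self.tol = tol
        self._best_score = None
        self._n_bad_iters = 0

    def on_iteration_end(self, estimator, validation_score):
        """Return True to tell the estimator to stop training."""
        if self._best_score is None or validation_score > self._best_score + self.tol:
            self._best_score = validation_score
            self._n_bad_iters = 0
            return False
        self._n_bad_iters += 1
        return self._n_bad_iters >= self.patience

# e.g. HistGradientBoostingClassifier(callbacks=[EarlyStopping(patience=10)])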

@amueller
Member
amueller commented Mar 8, 2024

I disagree with the "works" entries above. If you have a pipeline that consists of feature selection and an estimator, the feature selection will include the validation data, leading to potentially very misleading results.
The only solution I can see is adding a mechanism to Pipeline that knows how to treat a validation set, potentially piggy-backing on a modified version of the meta-data routing.
That's captured in the "Problem" table, but "having some data leakage" in validation doesn't seem acceptable to me. PS: I really wanted to make the draft meeting this week but I've been sick in and out of bed the whole week :(

@jeremiedbb
Member

The only solution I can see is adding a mechanism to Pipeline that knows how to treat a validation set, potentially piggy-backing on a modified version of the meta-data routing.

This is proposed here #28440 (comment)

but "having some data leakage" in validation doesn't seem acceptable to me

Yet it's how it's been done in scikit-learn until now (HGBT, SGD, ...). But I agree that it does not prevent us from finding a better solution :)
However, the better solution involves an evolution of the API and a lot of work, so it should be clearly motivated. That is, show a concrete, realistic example where the data leakage coming from the preprocessing is really detrimental. We discussed that IRL with @glemaitre but I haven't started yet.

ps I really wanted to make the draft meeting this week but I've been sick in and out of bed the whole week :(

Let's discuss it in the next one on March 22nd. I'll send a mail to the mailing list to make others aware.
We could also take the opportunity to discuss the callback API with the support of the newly written SLEP if you'd like.

@lorentzenchr
Member
lorentzenchr commented Mar 8, 2024

@amueller Are we talking about the same problem? In my understanding, the issue here is early stopping in the final estimator of a pipeline, not ES in a preprocessing step before that, nor an unbiased estimate of the out-of-sample performance à la cross-validation. Sure, the preprocessing has an effect on the final estimator, but ES should just avoid overfitting or spare resources of the final estimator. As @betatim stated in #28440 (comment), as long as the validation curve is only y-shifted, everything is fine.

Also sure, a mechanism to tell a pipeline to pass validation data through it is a missing piece in our API (even after metadata routing) and would solve this issue in the methodologically soundest way (or not soundly at all, as we just put the burden on the user). It was proposed by @adrinjalali in #28440 (comment) and, IMHO, should be proposed in its own issue or even a SLEP.

@thomasjpfan
Member

The issue around feature selection in a pipeline is described here: #28440 (comment). Concretely, something like this:

pipe = make_pipeline(
    SequentialFeatureSelector(...),
    HistGradientBoostingClassifier(early_stopping=True),
)

@amueller
Member
amueller commented Mar 25, 2024

@lorentzenchr Sorry, I was busy and missed the March 22nd draft meeting (the 7am slot is also somewhat inconvenient for me, but I would have made it if I'd known this was the issue being discussed).
And I don't see why the validation curve would only be y-shifted. I have an ICML rebuttal deadline this week, so I won't have time to develop the example this week, but in essence it's the pipeline that @thomasjpfan proposed. Early stopping is a hyper-parameter selection problem. Whether or not the validation set was considered in the feature selection creates completely different validation sets for the hyper-parameter selection problem, and I don't see how they would be related in this case.

@jeremiedbb:

Yet it's how it's been done in scikit-learn until now (HGBT, SGD, ...). But I agree that it does not prevent us from finding a better solution :)

It's been done that way in single estimators, and was impossible to do in pipelines. I.e. the issue we're talking about was absent because doing the wrong thing was impossible. Doing the right thing was also impossible, though.

@ogrisel
Member
ogrisel commented Aug 1, 2024

Cross-linking to a PR that stemmed from the discussion in #28440:
