ENH Allow prefit in stacking by Micky774 · Pull Request #22215 · scikit-learn/scikit-learn

Merged
46 commits merged into scikit-learn:main from Micky774:allow_prefit_in_stacking on Feb 22, 2022

Conversation

Contributor
@Micky774 Micky774 commented Jan 14, 2022

Reference Issues/PRs

Fixes #16556
Closes #16748
Continuation of stalled PR #16748

What does this implement/fix? Explain your changes.

(PR #16748):
Added support to use pre-fitted models in StackingClassifier and StackingRegressor.

Similar to CalibratedClassifierCV, I added the option to set cv="prefit" to use already fitted estimators in a stacking model.
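
For illustration, a minimal sketch of the option with StackingRegressor (the dataset and estimator choices here are placeholders, not taken from this PR):

from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor, StackingRegressor
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_base, X_stack, y_base, y_stack = train_test_split(X, y, random_state=0)

# Base estimators are fitted ahead of time (in practice, typically on a
# separate split so the final estimator does not overfit their training data).
estimators = [
    ("rf", RandomForestRegressor(n_estimators=10, random_state=0).fit(X_base, y_base)),
    ("ridge", Ridge().fit(X_base, y_base)),
]

# With cv="prefit" the stacker reuses the estimators as-is: fit() only trains
# the final estimator on their predictions, without refitting or cross-validating.
reg = StackingRegressor(estimators=estimators, final_estimator=Ridge(), cv="prefit")
reg.fit(X_stack, y_stack)
print(reg.score(X_stack, y_stack))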

(This PR)
Resolves remaining PR concerns, mainly regarding testing.

Any other comments?

Continuation of stalled PR #16748

- Updated `sklearn/ensemble/tests/test_stacking` to incorporate suggested changes from PR #16748
@Micky774 Micky774 changed the title from "[WIP] ENH Allow prefit in stacking" to "ENH Allow prefit in stacking" on Jan 15, 2022
Member
@thomasjpfan thomasjpfan left a comment

Thanks for the PR @Micky774! Overall looks good.

We can add a sentence to:

estimators_ : list of estimators

that says what happens when cv="prefit"
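
For instance, the attribute description could gain a sentence along these lines (the exact wording here is only a suggestion, not necessarily what was merged):

estimators_ : list of estimators
    The elements of the `estimators` parameter, having been fitted on the
    training data. When `cv="prefit"`, `estimators_` is set to `estimators`
    and is not fitted again.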

@Micky774
Contributor Author

Implemented changes, and made the documentation more consistent between StackingClassifier and StackingRegressor

Member
@thomasjpfan thomasjpfan left a comment

Thanks for the update!

Micky774 and others added 6 commits January 18, 2022 14:47
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
@thomasjpfan
Member
thomasjpfan commented Jan 25, 2022

@glemaitre Would you be interested in reviewing this PR?

@glemaitre
Member

Yep let me put it in my stack

@glemaitre glemaitre self-requested a review January 26, 2022 09:20
@thomasjpfan
Member

@jnothman Maybe you would be interested in reviewing this one? This is mostly the same as #16748 which already had your approval.

Member
@glemaitre glemaitre left a comment

The code seems fine. I just have one question regarding the API that I am not sure about.


if self.cv == "prefit":
    # Generate predictions from prefit models
    predictions = [
Member

@thomasjpfan do you think that we could benefit from parallelization over the models here?

Member
@thomasjpfan thomasjpfan Feb 22, 2022

I think there can be a benefit, but this can be done in a follow up PR with some benchmarks.
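
For reference, a self-contained sketch of the kind of per-model parallelization being discussed, using joblib directly on a list of prefit estimators. This only illustrates the idea, not the PR's implementation; the estimators, data, and n_jobs value are made up:

from joblib import Parallel, delayed
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Stand-ins for the already fitted base estimators of a prefit stacker.
prefit_estimators = [
    LogisticRegression(max_iter=1000).fit(X, y),
    DecisionTreeClassifier(random_state=0).fit(X, y),
]
stack_methods = ["predict_proba", "predict_proba"]

# One job per base estimator; whether this actually pays off depends on how
# expensive each estimator's prediction is, hence the call for benchmarks.
predictions = Parallel(n_jobs=2)(
    delayed(getattr(est, method))(X)
    for est, method in zip(prefit_estimators, stack_methods)
)
print([p.shape for p in predictions])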

@@ -306,7 +321,7 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
     The default classifier is a
     :class:`~sklearn.linear_model.LogisticRegression`.

-    cv : int, cross-validation generator or an iterable, default=None
+    cv : int, cross-validation generator, iterable, or 'prefit', default=None
Member

In the description above, we have a note regarding the fit of estimators_ on the full training set. It should be complemented with the option prefit (at least stating that we don't refit if cv="prefit").

Contributor Author

There's a description of the role of cv="prefit" both in the description of the cv parameter and in estimators_. I'm not sure I quite understand what you're suggesting here.

for estimator in all_estimators:
    if estimator != "drop":
        check_is_fitted(estimator)
        self.estimators_.append(estimator)
Member

I am unsure about our API contract here. If someone alters an estimator from all_estimators without calling fit on the stacking model, it will nevertheless affect the stacking model's predictions.

With a freezing API, changing a hyperparameter and calling fit would have no effect, so a deep copy would not be necessary to prevent the behaviour described above. However, without such a freezing API, we would need to make a deep copy of each estimator to prevent any side effects.

@thomasjpfan WDYT? I see that we have a similar pattern in calibration.

Member

Here is an example to explicitly illustrate what I mean:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import StackingClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42
)

X_train_1, X_train_2, y_train_1, y_train_2 = train_test_split(
    X_train, y_train, stratify=y_train, random_state=0, test_size=0.5,
)

estimators = [
    ('rf', RandomForestClassifier(n_estimators=10, random_state=42).fit(X_train_1, y_train_1)),
    ('svr', make_pipeline(StandardScaler(),
                          LinearSVC(random_state=42)).fit(X_train_1, y_train_1))
]

clf = StackingClassifier(
    estimators=estimators, final_estimator=LogisticRegression(), cv="prefit"
)

clf.fit(X_train, y_train)
print(f"Accuracy score: {clf.score(X_test, y_test):.2f}")

estimators[0][1].fit(X_train_2, y_train_2)

print(
    "Accuracy score after refitting an estimator: "
    f"{clf.score(X_test, y_test):.2f}"
)
Accuracy score: 0.89
Accuracy score after refitting an estimator: 0.95

Member

I think this is the norm for any mutable object we pass into __init__, such as dictionaries. We have not been systematic about deep copying because it comes with overhead.

As we discussed IRL, some may see your snippet as a feature as it behaves like Pipeline.

Logistically, given the status of freezing, waiting for a freezing API would mean this feature is delayed quite a bit. I think we decided in the 2019 sprint that we were going with option 4: #8370 (comment), which means no freezing API. But it could be worth revisiting now.

Member

Basically, we already have this issue with CalibratedClassifierCV. So I assume we will run into trouble in the future, but we might want to solve those cases as a whole.
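
For completeness, a user-side workaround under the behaviour as merged (the stacker keeps the prefit estimators by reference and does not copy them): passing deep copies to the stacker isolates it from later refits of the original objects. This is just a sketch of the idea, not something the PR adds:

import copy

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

rf = RandomForestClassifier(n_estimators=10, random_state=42).fit(X_train, y_train)

# Hand the stacker a deep copy so later refits of `rf` cannot change its
# predictions behind its back.
clf = StackingClassifier(
    estimators=[("rf", copy.deepcopy(rf))],
    final_estimator=LogisticRegression(),
    cv="prefit",
).fit(X_train, y_train)

score_before = clf.score(X_test, y_test)
rf.fit(X_test, y_test)  # refitting the original estimator...
score_after = clf.score(X_test, y_test)  # ...no longer affects the stacker
assert score_before == score_after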

@glemaitre
Member

The PR looks good. We only need to make a decision regarding the API, i.e. whether we do a deep copy or not.
I will add a tag so that we make a decision.

Member
@glemaitre glemaitre left a comment

+1 then. @Micky774 Could you resolve the conflict so that we can merge?

@@ -311,6 +311,9 @@ Changelog
- |Enhancement| Adds support to use pre-fit models with `cv="prefit"`
in :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor`.
:pr:`16748` by :user:`Siqi He <siqi-he>` and :pr:`22215` by
- |Enhancement| :class:`feature_selection.GenericUnivariateSelect` preserves
Member

It seems that there is an issue with the merging here.

Member
@thomasjpfan thomasjpfan Feb 17, 2022

I've been seeing more weird merge issues lately in the changelog in other PRs. It could be related to #21516

Contributor Author

I'm not sure I understand what the problem is -- everything looks fine on my end?

Member
@thomasjpfan thomasjpfan Feb 17, 2022

Looks fine now. Maybe the GitHub interface was doing something weird with the diff.

Member

It was missing your username: 9d6c27a

Usually, GitHub complains about a merge conflict in such cases. Here, merging locally does not show any conflicts, but git actually messes up the merge, as it did for this one. Basically, GitHub is right, but I don't know why git resolves this merge conflict on its own.

It generally only happens with the changelog.

@Micky774
Contributor Author

Just wanted to ping again to see if this is ready to merge -- it has two approvals and should be caught up with main.

@thomasjpfan thomasjpfan merged commit 691972a into scikit-learn:main Feb 22, 2022
@Micky774 Micky774 deleted the allow_prefit_in_stacking branch February 22, 2022 19:22
thomasjpfan added a commit to thomasjpfan/scikit-learn that referenced this pull request Mar 1, 2022
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Siqi He <siqi.he@upstart.com>
Successfully merging this pull request may close these issues.

Add Pre-fit Model to Stacking Model