Add a way to use `warm_start` together with `cross_val_score`? · Issue #22044 · scikit-learn/scikit-learn · GitHub
Open
@PGijsbers

Description

Discussed in #22042

Originally posted by PGijsbers December 21, 2021
I want to obtain scores for a random forest with different numbers of estimators across the same folds in a reproducible manner. I essentially want something like:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

x, y = make_classification(random_state=0)
rf = RandomForestClassifier(n_estimators=10, warm_start=True, random_state=0)
step_size = 10

for i in range(3):
    scores = cross_val_score(rf, x, y, cv=3)
    print(scores)
    rf.n_estimators += step_size

However, as I understand it, the forests are trained from scratch on every call in this scenario, because cross_val_score clones the original rf object. Instead, I would like to keep the fitted forest for each fold and continue training it in the next iteration. It is possible to get the fitted estimators with cross_validate(..., return_estimator=True), but I don't see any compatible way to pass in one estimator per fold.
Is there a way? Or should I instead open an issue for a feature request? :)
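
For reference, the behaviour described above can be approximated today by looping over the folds manually. This is only a minimal sketch, not an existing scikit-learn feature: it assumes the folds come from StratifiedKFold(n_splits=3), which is what cv=3 resolves to for a classifier, and the names splits, forests and all_scores are introduced here for illustration.

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold

x, y = make_classification(random_state=0)
rf = RandomForestClassifier(n_estimators=10, warm_start=True, random_state=0)
step_size = 10

# Materialize the splits once so every round is scored on identical folds.
cv = StratifiedKFold(n_splits=3)
splits = list(cv.split(x, y))

# One independent, unfitted copy of the forest per fold (assumption: the
# per-fold state is kept by the caller, not by scikit-learn).
forests = [clone(rf) for _ in splits]

all_scores = []
for i in range(3):
    scores = []
    for forest, (train_idx, test_idx) in zip(forests, splits):
        # With warm_start=True, fit() keeps the existing trees and only
        # trains the additional ones requested via n_estimators.
        forest.fit(x[train_idx], y[train_idx])
        scores.append(forest.score(x[test_idx], y[test_idx]))
        forest.n_estimators += step_size
    all_scores.append(scores)
    print(scores)
```

Each forest only ever sees the training indices of its own fold, so the trees added in later rounds extend a model fitted on the same data as before.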


Suggested change: allow cross_validate's estimator parameter to optionally be a list of estimators whose length matches the number of folds. From the user's side, it could then work something like:

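from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

x, y = make_classification(random_state=0)
rf = RandomForestClassifier(n_estimators=10, warm_start=True, random_state=0)
k_folds = 3
# One independent copy per fold; [rf] * k_folds would repeat the same object.
forests = [clone(rf) for _ in range(k_folds)]
step_size = 10

for i in range(3):
    # Proposed API (not currently supported): pass one estimator per fold.
    result = cross_validate(forests, x, y, cv=k_folds, return_estimator=True)
    forests = result["estimator"]
    for forest in forests:
        forest.n_estimators += step_size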

Internally, you can presumably just zip the estimators with the folds and be done, as I assume that the returned estimator list is already in the same order as the folds.
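
The zipping idea could look roughly like the sketch below. The function cross_validate_per_fold is a hypothetical stand-in written for illustration, not part of scikit-learn, and it only mimics the "test_score" and "estimator" keys of the dictionary that cross_validate returns. It assumes each estimator is fitted in place, without cloning, on the fold at the same position.

```python
import numpy as np
from sklearn.base import clone, is_classifier
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import check_cv


def cross_validate_per_fold(estimators, X, y, cv=3):
    """Hypothetical helper: fit estimators[i] on fold i without cloning."""
    # Resolve an integer cv the same way scikit-learn does for classifiers.
    cv = check_cv(cv, y, classifier=is_classifier(estimators[0]))
    splits = list(cv.split(X, y))
    if len(estimators) != len(splits):
        raise ValueError("need exactly one estimator per fold")
    test_scores = []
    # Pair each estimator with its fold; the order is preserved.
    for est, (train, test) in zip(estimators, splits):
        est.fit(X[train], y[train])
        test_scores.append(est.score(X[test], y[test]))
    return {"test_score": np.array(test_scores), "estimator": estimators}


x, y = make_classification(random_state=0)
rf = RandomForestClassifier(n_estimators=10, warm_start=True, random_state=0)
forests = [clone(rf) for _ in range(3)]

for i in range(3):
    result = cross_validate_per_fold(forests, x, y, cv=3)
    forests = result["estimator"]
    print(result["test_score"])
    for forest in forests:
        forest.n_estimators += 10
```

Because the estimators are returned in the order they were passed in, feeding result["estimator"] back into the next call keeps every forest attached to the same fold.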
