`prefit` option missing for voting? · Issue #12297 · scikit-learn/scikit-learn · GitHub

prefit option missing for voting? #12297


Open
tengerye opened this issue Oct 5, 2018 · 10 comments · May be fixed by #28434
Labels
Moderate Anything that requires some knowledge of conventions and best practices module:ensemble

Comments

@tengerye
tengerye commented Oct 5, 2018

I work on python 3.6.5 and sklearn 0.19.1 on conda 4.5.11.

According to this post, I am surprised to find that voting classifiers do not accept prefit classifiers.

May I know whether this is because you consider it unnecessary, or was it simply an oversight?
Thank you very much.

@tengerye
Author
tengerye commented Oct 5, 2018

My case is that I want to calibrate the base classifiers and then perform soft voting. Here is my code:

    from sklearn import svm
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import load_breast_cancer
    from sklearn.ensemble import VotingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    data = load_breast_cancer()

    # Data splitting.
    X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25)

    # Base classifiers.
    clf_svm = svm.SVC(gamma=0.001, probability=True)
    clf_svm.fit(X_train, y_train)

    clf_lr = LogisticRegression(random_state=0, solver='lbfgs')
    clf_lr.fit(X_train, y_train)

    # Calibrate the already-fitted base classifiers on the validation split.
    svm_isotonic = CalibratedClassifierCV(clf_svm, cv='prefit', method='isotonic')
    svm_isotonic.fit(X_val, y_val)

    lr_isotonic = CalibratedClassifierCV(clf_lr, cv='prefit', method='isotonic')
    lr_isotonic.fit(X_val, y_val)

    # Soft voting over the calibrated classifiers -- fitting the ensemble
    # clones and refits them, which is where it fails.
    eclf_soft2 = VotingClassifier(estimators=[
        ('svm', svm_isotonic), ('lr', lr_isotonic)], voting='soft')
    eclf_soft2.fit(X_val, y_val)

Now it throws exception:

Traceback (most recent call last):
  File "/home/ubuntu/projects/faceRecognition/faceVerif/util/plot_calibration.py", line 125, in <module>
    main(parse_arguments(sys.argv[1:]))
  File "/home/ubuntu/projects/faceRecognition/faceVerif/util/plot_calibration.py", line 120, in main
    eclf_soft2.fit(X_val, y_val)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/ensemble/voting_classifier.py", line 189, in fit
    for clf in clfs if clf is not None)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/ensemble/voting_classifier.py", line 31, in _parallel_fit_estimator
    estimator.fit(X, y)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/calibration.py", line 157, in fit
    calibrated_classifier.fit(X, y)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/calibration.py", line 335, in fit
    df, idx_pos_class = self._preproc(X)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/calibration.py", line 290, in _preproc
    df = self.base_estimator.decision_function(X)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py", line 527, in decision_function
    dec = self._decision_function(X)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py", line 384, in _decision_function
    X = self._validate_for_predict(X)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py", line 437, in _validate_for_predict
    check_is_fitted(self, 'support_')
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/utils/validation.py", line 768, in check_is_fitted
    raise NotFittedError(msg % {'name': type(estimator).__name__})
sklearn.exceptions.NotFittedError: This SVC instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

If I could perform soft voting with prefit estimators, I think everything would be fine.
Thank you in advance.

@amueller
Member
amueller commented Oct 5, 2018

I forgot what the exact discussion was, but now I tend to agree that adding this option would be useful.
I think my argument used to be that it's very easy to implement yourself, but I don't think that's a great argument.
There will be issues with prefit here as in #8370 and #6451 but that shouldn't prevent it from being useful in some cases.
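For reference, the "implement it yourself" route mentioned above can be quite short: soft voting is just an average of `predict_proba` outputs, so already-fitted classifiers can be combined without refitting anything. A minimal sketch (the dataset, classifiers, and parameters here are illustrative, not the eventual API):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X_train, X_test, y_train, y_test = train_test_split(
    *load_breast_cancer(return_X_y=True), test_size=0.2, random_state=0)

# Fit the base classifiers once, up front.
clfs = [
    SVC(gamma=0.001, probability=True, random_state=0),
    LogisticRegression(max_iter=5000, random_state=0),
]
for clf in clfs:
    clf.fit(X_train, y_train)

# "Soft voting" by hand: average the class-probability estimates of the
# already-fitted models, then pick the class with the highest mean probability.
avg_proba = np.mean([clf.predict_proba(X_test) for clf in clfs], axis=0)
y_pred = clfs[0].classes_[np.argmax(avg_proba, axis=1)]
```

This relies on all base classifiers sharing the same `classes_` ordering, which holds when they were fitted on the same target labels.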

@tengerye
Author
tengerye commented Oct 6, 2018

Great, glad to see there is a plan. I will join the discussion on those issues. Many thanks, @amueller.

@tengerye tengerye closed this as completed Oct 6, 2018
@amueller
Member
amueller commented Oct 7, 2018

I think we should keep this open as it's somewhat different.

@amueller amueller reopened this Oct 7, 2018
@thomasjpfan thomasjpfan added Sprint Moderate Anything that requires some knowledge of conventions and best practices labels Aug 20, 2019
@kwenlyou

Another scenario is that some models are expensive to obtain; they could be trained using something like Dask. At the ensembling stage it is unnecessary and time-consuming to train them again.

Btw, is there any plan to resolve this issue recently? Thanks.

@aliechoes

Is there any update on this thread?

@GO-Loc-GO

Looking forward to an update on this issue. I'm currently working on a project where this feature would have come in handy for my experiments; it's a pity it's not available.

@vyasprateek

It's been open for 5 years and is a feature many people need. Is there any plan to work on this?

@lizhuoq
lizhuoq commented Jan 12, 2024

I find this feature very useful and look forward to future updates.

@eddiebergman
Contributor
eddiebergman commented Feb 16, 2024

I would attempt this feature as a requirement for a library that builds on top of sklearn. I can produce a naive first attempt, but will then need guidance on other non-obvious aspects that need to be considered, such as #6451 and #8370 as mentioned by @amueller above. If any of these issues need to be addressed first, please let me know.

My current workarounds for this problem are directly manipulating the fitted `_` attributes or writing custom classes, but as a feature directly supported by sklearn it would be much more native and exportable.

Please let me know if any core dev would be happy to mentor through a draft PR I've created at #28434
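For anyone needing this before a fix lands, the attribute-manipulation workaround mentioned above can be sketched roughly as follows. It assigns the attributes that `VotingClassifier.fit` would normally set (`estimators_`, `le_`, `classes_`) so that `fit` is never called. These are version-dependent internals, so treat this as an illustration under that assumption, not a supported API:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC

X_train, X_test, y_train, y_test = train_test_split(
    *load_breast_cancer(return_X_y=True), test_size=0.2, random_state=0)

# Fit the base classifiers once; we do NOT want the ensemble to refit them.
svm_clf = SVC(gamma=0.001, probability=True, random_state=0).fit(X_train, y_train)
lr_clf = LogisticRegression(max_iter=5000, random_state=0).fit(X_train, y_train)

eclf = VotingClassifier(
    estimators=[("svm", svm_clf), ("lr", lr_clf)], voting="soft")

# Bypass eclf.fit() entirely by assigning the attributes it would set.
# WARNING: these are internals and may change between sklearn releases.
eclf.estimators_ = [svm_clf, lr_clf]
eclf.le_ = LabelEncoder().fit(y_train)
eclf.classes_ = eclf.le_.classes_

# The ensemble now predicts using the prefit estimators, no refitting.
y_pred = eclf.predict(X_test)
```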
