[WIP] ENH Adds caching to multimetric scoring with a wrapper class by thomasjpfan · Pull Request #14261 · scikit-learn/scikit-learn · GitHub

Closed

thomasjpfan
Member

Reference Issues/PRs

Resolves #10802
Alternative to #10979

What does this implement/fix? Explain your changes.

Wraps the estimator, and caches the results of predict, predict_proba, decision_function, and score during _multimetric_score.
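A minimal sketch of the wrapper idea described above, not the code from this PR. The class name `CachingEstimator` and the demo estimator are hypothetical, and the cache is keyed by method name only, which assumes the same test data is passed on every call.

```python
class CachingEstimator:
    """Hypothetical wrapper that caches results of prediction methods."""

    # Method names listed in the PR description.
    _CACHED = ("predict", "predict_proba", "decision_function", "score")

    def __init__(self, estimator):
        self._estimator = estimator
        self._cache = {}

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails. Refusing private
        # names avoids infinite recursion if _estimator is not set yet.
        if name.startswith("_"):
            raise AttributeError(name)
        attr = getattr(self._estimator, name)
        if name not in self._CACHED or not callable(attr):
            return attr

        def cached(*args, **kwargs):
            # Assumption: the arguments are identical on every call, so the
            # method name alone is a sufficient cache key.
            if name not in self._cache:
                self._cache[name] = attr(*args, **kwargs)
            return self._cache[name]

        return cached


class CountingEstimator:
    """Toy estimator (hypothetical) that counts predict calls."""

    def __init__(self):
        self.n_predict_calls = 0

    def predict(self, X):
        self.n_predict_calls += 1
        return [x > 0 for x in X]


est = CountingEstimator()
wrapped = CachingEstimator(est)
X_test = [-1, 2, 3]
first = wrapped.predict(X_test)
second = wrapped.predict(X_test)  # served from the cache
```

Because the wrapper delegates through `__getattr__`, `hasattr(wrapped, "predict_proba")` still reflects what the wrapped estimator provides.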

@thomasjpfan thomasjpfan changed the title [MRG] ENH Adds caching by wrapping to multimetric scoring [MRG] ENH Adds caching to multimetric scoring by wrapping Jul 5, 2019
@thomasjpfan thomasjpfan changed the title [MRG] ENH Adds caching to multimetric scoring by wrapping [MRG] ENH Adds caching to multimetric scoring with a context manager Jul 5, 2019
@jnothman
Member
jnothman commented Jul 5, 2019

This is a nice solution, I think. Memory could be a small problem, but that is unlikely. This assumes all estimators have __dict__, and if we are to rely on that, we need to check it in the estimator checks.

@jnothman jnothman left a comment

Looking pretty good. Add what's new, please

cache = {}
names = ['predict', 'predict_proba', 'decision_function', 'score']
cache_funcs = {name: getattr(estimator, name) for name in names if
               hasattr(estimator, name)}

Might as well also check that it is callable?


cache = {}
names = ['predict', 'predict_proba', 'decision_function', 'score']
cache_funcs = {name: getattr(estimator, name) for name in names if

Maybe name it orig_funcs

@codecov
codecov bot commented Jul 5, 2019

Codecov Report

Merging #14261 into master will decrease coverage by 0.06%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master   #14261      +/-   ##
==========================================
- Coverage   96.82%   96.75%   -0.07%     
==========================================
  Files         394      394              
  Lines       71916    71949      +33     
  Branches     7904     7906       +2     
==========================================
- Hits        69630    69616      -14     
- Misses       2263     2315      +52     
+ Partials       23       18       -5
Impacted Files Coverage Δ
sklearn/model_selection/tests/test_validation.py 99.04% <100%> (+0.01%) ⬆️
sklearn/model_selection/_validation.py 98.27% <100%> (+0.11%) ⬆️
sklearn/utils/fixes.py 39.37% <0%> (-27.56%) ⬇️
sklearn/ensemble/_hist_gradient_boosting/loss.py 97.61% <0%> (-2.39%) ⬇️
sklearn/manifold/spectral_embedding_.py 87.65% <0%> (-1.86%) ⬇️
sklearn/_build_utils/__init__.py 55.35% <0%> (-1.79%) ⬇️
sklearn/neighbors/tests/test_dist_metrics.py 97.5% <0%> (-0.84%) ⬇️
sklearn/linear_model/ridge.py 97.83% <0%> (-0.73%) ⬇️
sklearn/impute/_iterative.py 97.51% <0%> (-0.5%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update ec6d0bb...6ed3613.

@thomasjpfan
Member Author
thomasjpfan commented Jul 5, 2019

> This assumes all estimators have __dict__

Is there a way to write an estimator without a __dict__?

Edit: joblib makes things not writable
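For reference, in plain Python a class that defines `__slots__` (and whose bases do the same) produces instances without a `__dict__`, so patching a method on the instance fails. The sketch below uses a hypothetical `SlottedEstimator`; it is not a scikit-learn class and says nothing about how joblib behaves.

```python
class SlottedEstimator:
    """Hypothetical estimator whose instances have no __dict__."""

    __slots__ = ("coef_",)

    def predict(self, X):
        return [0 for _ in X]


est = SlottedEstimator()
has_dict = hasattr(est, "__dict__")

try:
    # There is no instance __dict__ to hold the patched method.
    est.predict = lambda X: "patched"
    patch_error = None
except AttributeError as exc:
    patch_error = exc
```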

@thomasjpfan thomasjpfan force-pushed the multimetric_caching branch from 6ed3613 to ae8d7a4 Compare July 5, 2019 13:30
@thomasjpfan thomasjpfan changed the title [MRG] ENH Adds caching to multimetric scoring with a context manager [WIP] ENH Adds caching to multimetric scoring with a context manager Jul 5, 2019
@thomasjpfan
Member Author

I am unhappy with the two approaches I tried:

  1. Using a context manager with setattr fails when used with joblib.
  2. Using a wrapper class does not work with custom scorers. We could check if a scorer is a _PredictScorer, etc. and only wrap the estimator in those cases.
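The first approach can be sketched roughly as follows. Only the `partial(_call_func, name=name, func=func, cache=cache)` call and the list of method names appear in this thread; the body of `_call_func`, the name `cache_predictions`, and the restore logic are assumptions. The sketch also folds in the review suggestions to check `callable` and to name the mapping `orig_funcs`.

```python
from contextlib import contextmanager
from functools import partial

NAMES = ("predict", "predict_proba", "decision_function", "score")


def _call_func(*args, name, func, cache, **kwargs):
    # Assumed body: compute once per method name, then reuse the result.
    if name not in cache:
        cache[name] = func(*args, **kwargs)
    return cache[name]


@contextmanager
def cache_predictions(estimator):
    """Hypothetical context manager that patches methods on the instance."""
    cache = {}
    orig_funcs = {name: getattr(estimator, name) for name in NAMES
                  if callable(getattr(estimator, name, None))}
    patched = []
    try:
        for name, func in orig_funcs.items():
            setattr(estimator, name,
                    partial(_call_func, name=name, func=func, cache=cache))
            patched.append(name)
        yield estimator
    finally:
        # Assumption: the originals live on the class, so deleting the
        # instance attribute makes them visible again.
        for name in patched:
            delattr(estimator, name)


class Demo:
    """Toy estimator (hypothetical) that counts predict calls."""

    def __init__(self):
        self.calls = 0

    def predict(self, X):
        self.calls += 1
        return list(X)


est = Demo()
with cache_predictions(est) as cached_est:
    a = cached_est.predict([1, 2])
    b = cached_est.predict([1, 2])  # served from the cache
```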

@thomasjpfan
Member Author

This PR is updated to:

  1. Only cache if the scorer is a _BaseScorer, i.e. created with make_scorer.

@thomasjpfan thomasjpfan changed the title [WIP] ENH Adds caching to multimetric scoring with a context manager [MRG] ENH Adds caching to multimetric scoring with a class wrapper Jul 5, 2019
@thomasjpfan thomasjpfan changed the title [MRG] ENH Adds caching to multimetric scoring with a class wrapper [MRG] ENH Adds caching to multimetric scoring with a wrapper class Jul 5, 2019
@amueller
Member
amueller commented Jul 5, 2019

For some reason I feel iffy about the cache not checking X, even though X_test is the same for the lifetime of the estimator.
It wouldn't hurt to store a reference to X and check object identity, right?
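One way the identity check could look, as a sketch only: the helper name `IdentityCheckedCache` is hypothetical. It keeps a strong reference to the last `X` and recomputes whenever a different object is passed, even if the contents are equal.

```python
class IdentityCheckedCache:
    """Hypothetical single-entry cache keyed on the identity of X."""

    def __init__(self, func):
        self._func = func
        self._X = None
        self._result = None
        self._filled = False

    def __call__(self, X):
        # Recompute unless the very same object is passed again.
        if not self._filled or X is not self._X:
            self._result = self._func(X)
            self._X = X  # strong reference, so the object cannot be recycled
            self._filled = True
        return self._result


calls = []


def predict(X):
    calls.append(len(X))
    return [2 * x for x in X]


cached_predict = IdentityCheckedCache(predict)
X_test = [1, 2, 3]
r1 = cached_predict(X_test)
r2 = cached_predict(X_test)        # same object: cache hit
r3 = cached_predict(list(X_test))  # equal copy, different object: recomputed
```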

Generally I would kind of prefer not to use these hacks and expand the scorer interface to allow returning dicts of scores so we can do The Right Thing™ directly. That would also allow us to use precision_recall_fscore_support etc, which is now still called twice if you want precision and recall.
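To illustrate the dict-returning idea, here is a pure-Python sketch of a scorer that calls `predict` once and derives both precision and recall from the same predictions. The function name, the return format, and the toy estimator are hypothetical and are not an interface proposed in this thread; binary labels with positive class 1 are assumed.

```python
def precision_recall_scorer(estimator, X, y_true):
    """Hypothetical scorer returning several metrics from one predict call."""
    y_pred = estimator.predict(X)
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    # Define the metric as 0.0 when its denominator is zero.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {"precision": precision, "recall": recall}


class ThresholdEstimator:
    """Toy estimator (hypothetical): predicts 1 when the value exceeds 0.5."""

    def __init__(self):
        self.n_predict_calls = 0

    def predict(self, X):
        self.n_predict_calls += 1
        return [int(x > 0.5) for x in X]


est = ThresholdEstimator()
scores = precision_recall_scorer(est, [0.9, 0.2, 0.4, 0.1], [1, 1, 1, 0])
```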

@jnothman
Member
jnothman commented Jul 6, 2019

I'm happy with investigating the scorer returning dict solution, but I am not sure that will quickly remedy the present efficiency problem. Efficiently calculating per-class and micro average prf would require some rewriting of the metrics anyway...

@jnothman
Member
jnothman commented Jul 6, 2019 via email

@thomasjpfan thomasjpfan changed the title [MRG] ENH Adds caching to multimetric scoring with a wrapper class [WIP] ENH Adds caching to multimetric scoring with a wrapper class Jul 6, 2019
@thomasjpfan
Member Author

@jnothman Here is the error from test_gridsearch:

        # patch methods
        for name, func in cache_funcs.items():
            setattr(estimator, name,
>                   partial(_call_func, name=name, func=func, cache=cache))
E           AttributeError: can't set attribute

@thomasjpfan
Member Author

The issue is how predict_proba is a @property in Voting*. The following will raise an AttributeError:

class Hello:
    @property
    def hello(self):
        return "world"

h = Hello()
setattr(h, "hello", 1)

@jnothman
Member
jnothman commented Jul 6, 2019 via email

@thomasjpfan
Member Author

Does not seem to work. To get this to work, we would need to define a setter for the property:

class MyObj:
    def _hello(self):
        return "world"

    @property
    def hello(self):
        return self._hello

    @hello.setter
    def hello(self, func):
        self._hello = func


def another_hello():
    return "a whole new world"


obj = MyObj()
old_func = obj.hello
print("before setattr", obj.hello())

setattr(obj, "hello", another_hello)
print("after setattr", obj.hello())

setattr(obj, "hello", old_func)
print("restore", obj.hello())

@thomasjpfan
Member Author

The issue is not with the @property, but with how VotingClassifier was not duck typing correctly. I opened #14287 to address this issue.
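As a general illustration of duck typing with a property, and not a description of what #14287 changes: if a property raises AttributeError on access when the method is unavailable, `hasattr` reports the method as missing instead of handing back something that fails only when called. The class `HardOrSoftVoter` and its `voting` parameter are hypothetical stand-ins.

```python
class HardOrSoftVoter:
    """Hypothetical estimator exposing predict_proba only for soft voting."""

    def __init__(self, voting="hard"):
        self.voting = voting

    def _predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]

    @property
    def predict_proba(self):
        # Raising on attribute access makes hasattr() return False.
        if self.voting == "hard":
            raise AttributeError(
                "predict_proba is not available when voting='hard'")
        return self._predict_proba


hard_has_proba = hasattr(HardOrSoftVoter("hard"), "predict_proba")
soft_has_proba = hasattr(HardOrSoftVoter("soft"), "predict_proba")
```

With this pattern, a check such as `hasattr(estimator, name)` in the caching code would skip `predict_proba` for the hard-voting instance.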

Successfully merging this pull request may close these issues.

Multi-metric scoring is incredibly slow because it repeats predictions for every metric