10000 FEA Callbacks base infrastructure + progress bars (Alternative to #27663) by jeremiedbb · Pull Request #28760 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FEA Callbacks base infrastructure + progress bars (Alternative to #27663) #28760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 78 commits into
base: callbacks
Choose a base branch
from

Conversation

jeremiedbb
Copy link
Member

Alternative to #27663 based on feedback from the drafting meeting. I'm keeping both open for now for easier comparison.

Copy link
github-actions bot commented Apr 3, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: def2687. Link to the linter CI: here

@jeremiedbb
Copy link
Member Author

@glemaitre and @adrinjalali I think this is ready for reviews. I implemented the changes discussed during some drafting meetings compared to #27663.

A quick summary.

From the point of view of a user:

  • callback objects are available from sklearn.callback. In this PR, only ProgressBar is implemented.
  • callbacks are registered to estimators using the set_callbacks(ProgressBar()).
  • if I want to enable the ProgressBar to all inner steps of my pipeline: ProgressBar(max_estimator_depth=None). Limited to depth=1 by default for performance considerations.

From the point of view of a third party developer of estimators:

  • Callback support is enabled through the CallbackSupportMixin
  • A CallbackContext object must be created at the beginning of fit (using init_callback_context from the mixin). This object is then used to create sub-contexts for subtasks, evaluate the callback hooks, and eventually propagate the callbacks to sub-estimators.

From the point of view of a third party developer of callbacks:

  • Callbacks must follow the CallbackProtocol protocol (3 hooks essentially).
  • Callbacks that should be propagated to inner estimators must follow an additional protocol.
  • The 3 hooks, receive different args. Their signature will be refined and better documented in follow-up PRs. One key arg is the node of the current task, which is a TaskNode object and is useful to find at which step the hook was called and take actions upon that.

@glemaitre glemaitre self-requested a review July 25, 2024 08:21
Copy link
Member
@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my recall from the meeting, it looks what we discussed. I did a pass in the test to have a feeling on the usage from different perspective.

For me this is looking good.

@adrinjalali do you want to have a look at it.

@@ -130,6 +130,10 @@ def _clone_parametrized(estimator, *, safe=True):

params_set = new_object.get_params(deep=False)

# attach callbacks to the new estimator
if hasattr(estimator, "_skl_callbacks"):
new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading this line, it makes me think that we should have a test in the callback file to check that clone does what it is supposed to do. Basically, here we are making some assumptions and it would be great that our tests are checking those.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree but the only callback in this PR (progress bar) is never cloned because it's always propagated from an outer estimator. So it's not clear what the exact behavior we want for clone.

That's why I'd rather clear that out when implementing the other callbacks and add the appropriate tests then.
(Note that I still need to define a clone somehow currently for tests that check for error messages, independent of the exact behavior of clone though).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to have a test for equivalence of _skl_callbacks on objects in estimator_checks.py to make sure as we keep developing, things stay consistent, including third parties. I don't mind if that happens in a separate PR, before we merge into main

max_estimator_depth : int, default=1
The maximum number of nested levels of estimators to display progress bars for.
By default, only the progress bars of the outermost estimator are displayed.
If set to None, all levels are displayed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not in this PR but in the branch targetting main, we will need to more documentation. Here, we are missing the attribute and probably an example usage.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, examples will come in subsequent PRs. This param in particular is something that I'd like to improve though (see the AutoPropagatedProtocol). It can be improved in a subsequent PR as well, but definitely before merging in main.

Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a look at this, I like the API for third party developers, but two things here are missing for me:

  • documentation: I feel the documentation is quite sparse. I would like to be able to understand what classes are for what purpose, what's a task, etc.
  • API: I feel everything which is not a dev API is private here, but used as a public attribute all around, which makes me uneasy. There should be a better distinction between things only the class itself touches, and things which should be used by outsider classes (yet inside callback infra)

class CallbackContext:
"""Task level context for the callbacks.

This class is responsible for managing the callbacks and task tree of an estimator.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's a task tree?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added more doc in the docstring of the TaskNode class. This class is public and has some public methods because a task node is passed to the callback hooks and can be used by a callback developer to perform different operations depending on the task at hand.

@jeremiedbb
Copy link
Member Author

I feel everything which is not a dev API is private here, but used as a public attribute all around,

Can you give me specific examples, I'm not sure what you're talking about ?

class CallbackProtocol(Protocol):
"""Protocol for the callbacks"""

def _on_fit_begin(self, estimator, *, data):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for instance, why are these private?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made these private because callback classes are public to the end user of scikit-learn but not its method which should only be implemented by callback developers and never called by anyone but scikit-learn internals.

I wanted to avoid these methods to appear on auto-completion in notebooks for instance.

It's hard to differentiate 3 levels of privacy (maintainers, third-party devs and end users) with a binary marker, leading underscore or not 😄

@jeremiedbb
Copy link
Member Author

I agree that CallbackContext is still missing documentation. However this is the main, and only really, object needed to implement callback support in an estimator. My goal is to make a subsequent PR with a detailed example on how to implement callback support in an estimator which will better explain what is and how to use the CallbackContext class than just adding more to the docstring. What do you think ? To me this where the focus on doc should be the highest.

Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Levels of abstraction, for third party devs at this point we do __sklearn_method_name__ pattern, if they're supposed to implement / override them to modify behavior.

I don't think we should use _method_name much, on things which are called outside the class itself (as in, really treat them as private). But to be able to suggest alternatives, I'd need to better understand this part of the codebase.

Could you maybe add a README file to the callback folder, explaining conceptually the overview of what each object / file is supposed to do? That makes it easier for me to review this, and understand where the scoping challenges are.

@@ -130,6 +130,10 @@ def _clone_parametrized(estimator, *, safe=True):

params_set = new_object.get_params(deep=False)

# attach callbacks to the new estimator
if hasattr(estimator, "_skl_callbacks"):
new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to have a test for equivalence of _skl_callbacks on objects in estimator_checks.py to make sure as we keep developing, things stay consistent, including third parties. I don't mind if that happens in a separate PR, before we merge into main

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

3 participants
0