FEA Callbacks base infrastructure + progress bars (Alternative to #27663) #28760
base: callbacks
Conversation
@glemaitre and @adrinjalali I think this is ready for review. Compared to #27663, I implemented the changes discussed during the drafting meetings. A quick summary. From the point of view of a user:
From the point of view of a third party developer of estimators:
From the point of view of a third party developer of callbacks:
From my recollection of the meeting, this looks like what we discussed. I did a pass over the tests to get a feeling for the usage from different perspectives.
For me this is looking good.
@adrinjalali do you want to have a look at it?
```diff
@@ -130,6 +130,10 @@ def _clone_parametrized(estimator, *, safe=True):

     params_set = new_object.get_params(deep=False)

+    # attach callbacks to the new estimator
+    if hasattr(estimator, "_skl_callbacks"):
+        new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
```
Reading this line, it makes me think that we should have a test in the callback file to check that `clone` does what it is supposed to do. Basically, we are making some assumptions here, and it would be great if our tests checked them.
I agree, but the only callback in this PR (the progress bar) is never cloned because it's always propagated from an outer estimator. So it's not clear what exact behavior we want for `clone`.
That's why I'd rather sort that out when implementing the other callbacks, and add the appropriate tests then.
(Note that I currently still need to define some `clone` behavior for the tests that check error messages, independent of its exact semantics.)
I think it makes sense to have a test in estimator_checks.py for the equivalence of `_skl_callbacks` on cloned objects, to make sure things stay consistent as we keep developing, including for third parties. I don't mind if that happens in a separate PR, before we merge into main.
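To make the assumption concrete, here is a toy sketch (not scikit-learn code) of the behavior the diff adds and the kind of equivalence check being asked for. The `toy_clone` helper, the `Estimator` class, and the use of `copy.deepcopy` in place of `clone(..., safe=False)` are all illustrative assumptions:

```python
import copy

class Estimator:
    """Minimal stand-in for a scikit-learn estimator."""
    def __init__(self, alpha=1.0):
        self.alpha = alpha

def toy_clone(estimator):
    # Rebuild the object from its constructor params, as sklearn.clone does...
    new_object = type(estimator)(alpha=estimator.alpha)
    # ...then re-attach callbacks, mirroring the logic added in the diff.
    if hasattr(estimator, "_skl_callbacks"):
        new_object._skl_callbacks = copy.deepcopy(estimator._skl_callbacks)
    return new_object

est = Estimator(alpha=0.5)
est._skl_callbacks = [{"name": "progress_bar"}]

cloned = toy_clone(est)
# The equivalence an estimator_checks-style test could assert:
assert cloned._skl_callbacks == est._skl_callbacks
assert cloned._skl_callbacks is not est._skl_callbacks  # not shared state
```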
```python
max_estimator_depth : int, default=1
    The maximum number of nested levels of estimators to display progress bars for.
    By default, only the progress bars of the outermost estimator are displayed.
    If set to None, all levels are displayed.
```
Maybe not in this PR, but in the branch targeting main we will need more documentation. Here, we are missing the attribute documentation and probably a usage example.
I agree, examples will come in subsequent PRs. This param in particular is something that I'd like to improve though (see the `AutoPropagatedProtocol`). It can be improved in a subsequent PR as well, but definitely before merging into main.
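The semantics described in the docstring above can be sketched as a small predicate. This is a toy illustration of the documented behavior, not the PR's implementation; the `should_display` name and the depth convention (depth 0 = outermost estimator) are assumptions:

```python
def should_display(depth, max_estimator_depth=1):
    """Decide whether to show a progress bar for a task at a nesting depth.

    depth 0 is the outermost estimator (assumed convention); by default only
    that level is shown, and max_estimator_depth=None shows every level.
    """
    if max_estimator_depth is None:
        return True
    return depth < max_estimator_depth

assert should_display(0) is True   # outermost estimator: shown by default
assert should_display(1) is False  # nested estimator: hidden by default
assert should_display(3, max_estimator_depth=None) is True  # all levels
```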
Having a look at this, I like the API for third party developers, but two things are missing for me:
- documentation: the documentation is quite sparse. I would like to be able to understand which classes serve what purpose, what a task is, etc.
- API: everything that is not a dev API is private here, yet it is used as a public attribute all around, which makes me uneasy. There should be a better distinction between things only the class itself touches and things that should be used by outside classes (still within the callback infrastructure).
```python
class CallbackContext:
    """Task level context for the callbacks.

    This class is responsible for managing the callbacks and task tree of an estimator.
```
what's a task tree?
I've added more documentation in the docstring of the `TaskNode` class. This class is public and has some public methods because a task node is passed to the callback hooks and can be used by a callback developer to perform different operations depending on the task at hand.
Can you give me specific examples? I'm not sure what you're talking about.
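As a rough mental model of a task tree (a toy sketch, not the PR's `TaskNode` API): each node represents one unit of work an estimator performs, and nodes nest, so a callback can inspect the depth or name of the task at hand. The attribute names below are illustrative assumptions:

```python
class ToyTaskNode:
    """Toy task-tree node: one unit of work (a fit, a CV split, an iteration)."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)

    @property
    def depth(self):
        # Root task (the outermost fit) sits at depth 0.
        return 0 if self.parent is None else self.parent.depth + 1

# A callback receiving `iteration` can tell how deeply nested its task is.
root = ToyTaskNode("fit")                # outermost estimator's fit
split = ToyTaskNode("cv_split", root)    # one cross-validation split
iteration = ToyTaskNode("iter", split)   # one iteration within that split

assert iteration.depth == 2
assert [c.name for c in root.children] == ["cv_split"]
```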
```python
class CallbackProtocol(Protocol):
    """Protocol for the callbacks"""

    def _on_fit_begin(self, estimator, *, data):
```
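As a rough illustration (not scikit-learn's actual internals), a callback satisfying this protocol could look like the sketch below. Only the `_on_fit_begin(self, estimator, *, data)` hook appears in the excerpt, so the sketch sticks to that one; the `LoggingCallback` and `DummyEstimator` names are made up for the example:

```python
from typing import Protocol

class CallbackProtocol(Protocol):
    """Structural protocol: any class with a matching hook conforms."""
    def _on_fit_begin(self, estimator, *, data): ...

class LoggingCallback:
    """Toy callback that records which estimators started fitting."""
    def __init__(self):
        self.events = []

    def _on_fit_begin(self, estimator, *, data):
        # Called by the (hypothetical) framework at the start of fit.
        self.events.append(type(estimator).__name__)

class DummyEstimator:
    pass

cb = LoggingCallback()
cb._on_fit_begin(DummyEstimator(), data=None)
assert cb.events == ["DummyEstimator"]
```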
for instance, why are these private?
I made these private because callback classes are public to the end user of scikit-learn, but their methods are not: they should only be implemented by callback developers and never called by anyone but scikit-learn internals.
I wanted to avoid these methods showing up in auto-completion in notebooks, for instance.
It's hard to differentiate 3 levels of privacy (maintainers, third-party devs and end users) with a binary marker, leading underscore or not 😄
I agree that
Levels of abstraction: for third party devs, at this point we use the `__sklearn_method_name__` pattern if they're supposed to implement / override a method to modify behavior.
I don't think we should use `_method_name` much on things which are called outside the class itself (as in, really treat them as private). But to be able to suggest alternatives, I'd need to better understand this part of the codebase.
Could you maybe add a README file to the callback folder, explaining conceptually what each object / file is supposed to do? That would make it easier for me to review this and understand where the scoping challenges are.
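For context on the pattern mentioned above: scikit-learn already uses dunder hooks such as `__sklearn_is_fitted__`, which third parties override and the library calls, without the method reading as private. The toy checker below mimics that dispatch; `check_is_fitted_toy` is a made-up name, not scikit-learn's `check_is_fitted`:

```python
def check_is_fitted_toy(estimator):
    """Toy version of fitted-state dispatch via a dunder hook."""
    if hasattr(estimator, "__sklearn_is_fitted__"):
        # Third party opted in to the hook: trust its answer.
        return estimator.__sklearn_is_fitted__()
    # Fallback heuristic: any attribute ending in "_" suggests a fitted model.
    return any(name.endswith("_") for name in vars(estimator))

class AlwaysFitted:
    def __sklearn_is_fitted__(self):
        return True

class PlainEstimator:
    pass

assert check_is_fitted_toy(AlwaysFitted()) is True
assert check_is_fitted_toy(PlainEstimator()) is False
```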
Alternative to #27663 based on feedback from the drafting meeting. I'm keeping both open for now for easier comparison.