ENH Add metadata routing to OneVsRestClassifier, OneVsOneClassifier and OutputCodeClassifier by StefanieSenger · Pull Request #27308 · scikit-learn/scikit-learn · GitHub

ENH Add metadata routing to OneVsRestClassifier, OneVsOneClassifier and OutputCodeClassifier #27308


Merged
merged 29 commits into scikit-learn:main on Oct 2, 2023

Conversation

StefanieSenger
Contributor
@StefanieSenger StefanieSenger commented Sep 6, 2023

Reference Issues/PRs

Towards #22893

What does this implement/fix? Explain your changes.

  • Adds metadata routing to OneVsRestClassifier, OneVsOneClassifier and OutputCodeClassifier (a usage sketch follows below)
  • The routing is added to the fit and partial_fit methods where they exist
  • A test is added to check that the coefficients differ when different sample weights are passed. Another test asserts that the correct error message is shown when metadata is passed while enable_metadata_routing is False.
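
A minimal usage sketch of what this enables; it is illustrative only and assumes the routing API described in the metadata routing user guide, not code from this PR:

import numpy as np
from sklearn import set_config
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

set_config(enable_metadata_routing=True)

X, y = load_iris(return_X_y=True)
sample_weight = np.ones_like(y, dtype=np.float64)

# The sub-estimator requests sample_weight; the meta-estimator routes it
# through in fit (and partial_fit, where available).
clf = OneVsRestClassifier(
    estimator=LogisticRegression().set_fit_request(sample_weight=True)
)
clf.fit(X, y, sample_weight=sample_weight)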

Comments

I had to put these three together because they share common functions.

Thanks @glemaitre for your help!

@github-actions
github-actions bot commented Sep 6, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: b1a8003.

Member
@adrinjalali adrinjalali left a comment

We need to add these classes to test_metaestimators_metadata_routing.py, and apply the same comments to the other instances in this PR.

Otherwise it's looking good.

Comment on lines 943 to 945
# TODO: should we raise an error if sample weight is requested but is not
# passed in the router's fit method when enable slep006 is True?
# clf.fit(iris.data, iris.target)
Member

OK I see, we can remove the comment then.

We would need to settle on the behaviour at some point. I am a bit on the same side as @thomasjpfan, finding it weird not to raise anything (I could be happy with a warning).

Comment on lines 967 to 969
"""Assert if coefficients differ when estimators are partial_fitted with different
`sample_weights`."""
sample_weight = np.ones_like(iris.target, dtype=np.float64)
Member

same here, the common tests should test this

Comment on lines 1004 to 1006
"""Test that the right error message is raised when metadata params
are passed while `enable_metadata_routing=False`."""
sample_weight = np.ones_like(iris.target, dtype=np.float64)
Member

Not sure if we need this; this machinery is tested at the base class level.

Contributor Author

Maybe, @glemaitre ?

Member

Fine with removing the tests if this is tested at a "common" level. It was handy to have them for quick development.

Comment on lines 982 to 987
clf.partial_fit(
iris.data,
iris.target,
classes=np.unique(iris.target),
sample_weight=sample_weight,
)
Contributor Author

Migrating this test to sklearn/tests/test_metaestimators_metadata_routing.py isn't easily done.

Reason: when partial_fitting OneVsRestClassifier and OneVsOneClassifier, an additional param (classes=np.unique(y)) needs to be passed, which the central tests don't account for (by the way, I wonder why the function isn't doing that internally):

sklearn/utils/multiclass.py:410: ValueError: classes must be passed on the first call to partial_fit.

This error occurs when implementing the central test with routing_methods: partial_fit; the tests then call OneVsRestClassifier.partial_fit() without classes (as they are supposed to).

In test_setting_request_removes_error, should we add a condition checking for additional params before calling method(X, y, **kwargs)? But then, in the next step, the methods of the artificial estimators from sklearn/tests/metadata_routing_common.py (for instance ConsumingClassifier) would need to be changed as well, so that they accept classes. I don't know how many meta-estimators' methods take additional params like classes, but wouldn't it be quite a lot?

Member

This is a good one. We should certainly fix our common tests to handle both cases, when partial_fit accepts classes and when it doesn't. This could be done with a try/except catching TypeError. As for adding classes to other partial_fit methods, that seems like a good idea; I wonder why they don't support it already.

Contributor Author

Okay, I will try that.

But what is the reason why partial_fit allows the user to define classes?

I cannot think of any other way than classes=np.unique(y) to define classes. Why / in what case would a user want to deviate from that?

Contributor Author

So, I made something for test_setting_request_on_sub_estimator_removes_error that looks like this:

            try:
                method(X, y, **method_kwargs)
            except ValueError:
                method(X, y, classes=np.unique(y), **method_kwargs)

And ConsumingClassifier from sklearn/tests/metadata_routing_common.py now accepts a classes=None parameter in its partial_fit method.

It works at least. Please let me know if this is correct, @adrinjalali

Member
@adrinjalali adrinjalali Sep 20, 2023

I cannot think of any other way than classes=np.unique(y) to define classes. Why / in what case would a user want to deviate from that?

Because the user might not have all the classes present in a batch of data passed to partial_fit. There can be a case where the batch passed to partial_fit includes only a subset of the total classes, and the classes are pre-defined by the user rather than derived from each batch.
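
For illustration (not part of this PR), a minimal sketch of that scenario using SGDClassifier as an incremental learner:

import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.RandomState(0)
X = rng.normal(size=(60, 4))
y = rng.randint(0, 3, size=60)  # three classes in total

clf = SGDClassifier(random_state=0)

# A batch may contain only a subset of the classes, so the full set of classes
# has to be declared up front (on the first call) rather than inferred per batch.
all_classes = np.array([0, 1, 2])
for i, start in enumerate(range(0, 60, 20)):
    batch = slice(start, start + 20)
    if i == 0:
        clf.partial_fit(X[batch], y[batch], classes=all_classes)
    else:
        clf.partial_fit(X[batch], y[batch])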

Member

I would write the code more like:

            try:
                method(X, y, classes=np.unique(y), **method_kwargs)
            except TypeError:
                method(X, y, **method_kwargs)

Contributor Author

I had to do it like this, because MultiOutputClassifier.partial_fit() expects its classes in another format.

            try:
                method(X, y, classes=np.unique(y), **method_kwargs)
            except TypeError:
                # the method does not accept `classes` in its signature
                method(X, y, **method_kwargs)
            except IndexError:
                # MultiOutputClassifier expects `classes` in a different format
                # (a list of arrays, one per output), so the flat array trips it up
                method(X, y, **method_kwargs)

@StefanieSenger
Contributor Author
StefanieSenger commented Sep 7, 2023

@adrinjalali, thank you for your review and the comments. :)

I've tried to migrate the tests to the central test file (sklearn/tests/test_metaestimators_metadata_routing.py), but due to unique characteristics of the multiclass classifiers' partial_fit methods and how OutputCodeClassifier creates its own sub-estimator in special cases, this is not possible without disruption.

I've tried to explain the reasons in the two comments and hope what I wrote is understandable.

Summary of what is done and not done:
The central test covers all the "fit" methods, but not correctly for OutputCodeClassifier, because it sometimes creates its own sub-estimator.
The central test does NOT cover "partial_fit", because it expects an additional param classes.


@@ -60,7 +60,7 @@ def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
     if key in split_params and recorded_value is not None:
         assert np.isin(recorded_value, value).all()
     else:
-        assert recorded_value is value
+        assert np.allclose(recorded_value, value)
Contributor Author
@StefanieSenger StefanieSenger Sep 7, 2023

I had to do this because of a very slight mismatch when testing OneVsOneClassifier.fit() (because of floating-point precision?).

But as I now see in the CI, this change has caused another test to fail (test_metadata_routing_for_column_transformer), so this is a catch. Should I revert that change? How else could I deal with the failure that I originally wanted to prevent?

Member

Could you please explain why this change was needed? The metadata shouldn't be touched by the meta-estimator, so there should be no numerical issues causing a mismatch. This would only fix an issue if we were somehow creating a new object from the old one, covering the whole of the old one, which might be the case if we're doing OVO on a binary classification problem since we'd select the whole data.

The error you're getting in the CI now is:

E           numpy.exceptions.DTypePromotionError: The DType <class 'numpy._FloatAbstractDType'> could not be promoted by <class 'numpy.dtypes.StrDType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is `object`. The full list of DTypes is: (<class 'numpy.dtypes.StrDType'>, <class 'numpy._FloatAbstractDType'>)

It's because this is now trying to compare non-numerical arrays, which it cannot do.
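
For reference, a tiny standalone sketch of why the identity check works for non-numerical metadata while np.allclose cannot handle it (the exact exception type may differ across NumPy versions):

import numpy as np

metadata = np.array(["a", "b", "a"])  # non-numerical metadata, e.g. groups
recorded = metadata                   # the meta-estimator passes the object through

assert recorded is metadata           # identity holds for any dtype

try:
    np.allclose(recorded, metadata)   # needs numeric promotion, fails on strings
except Exception as exc:
    print(type(exc).__name__)         # e.g. DTypePromotionError or TypeError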

Member

ping

Contributor Author
@StefanieSenger StefanieSenger Sep 20, 2023

@adrinjalali I had tried this because, when running test_setting_request_on_sub_estimator_removes_error, I had gotten the following error:

            if key in split_params and recorded_value is not None:
                assert np.isin(recorded_value, value).all()
            else:
>               assert recorded_value is value
E               AssertionError

Now I believe it's due to feeding only part of the data into the sub-estimator in _fit_ovo_binary. So only assert np.isin(recorded_value, value).all() should be checked, not assert recorded_value is value.

Setting "preserves_metadata": "subset", for ovo does this, but regarding your explanations here, I still don't understand, why partial_fit didn't raise the same issue, only fit did.

Member

That makes sense, since we're indexing the original array, and the indexing, even though it might select all rows, can return a new object (albeit a view into the old one). The `is` check therefore wouldn't pass, so we should set it as a subset in the tests. partial_fit might not be doing the sub-selection then.
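
A tiny sketch of the indexing point, independent of scikit-learn:

import numpy as np

sample_weight = np.ones(5)

# Fancy indexing that selects "everything" still allocates a new array,
# so an identity check fails even though the values are identical.
routed = sample_weight[np.arange(5)]
assert routed is not sample_weight
assert np.array_equal(routed, sample_weight)

# A plain slice, by contrast, returns a view whose base is the original array.
view = sample_weight[:]
assert view.base is sample_weight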

Contributor Author
@StefanieSenger StefanieSenger Sep 20, 2023

Nice, "preserves_metadata": "subset", it is then. :)

From what I understood, partial_fit also does a sub-selection. Could you look at the code to tell me if I'm wrong?

Member

yes it does. I'm not sure if the test passes then. But I wouldn't worry about it.

Contributor Author

Yes, the test for partial_fit also passes assert recorded_value is value, which it shouldn't.

I will immediately stop worrying about this, but not before telling you that this odd test passing when it shouldn't had led me totally off track while looking for the reasons why fit didn't pass the same test.

The problem is, your code is so flexible that it's both hard to read and hard to document.

Comment on lines 224 to 234
- |Enhancement| :class:`multiclass.OneVsRestClassifier`,
:class:`multiclass.OneVsOneClassifier` and
:class:`multiclass.OutputCodeClassifier` now support metadata routing.
:meth:`multiclass.OneVsRestClassifier.fit`,
:meth:`multiclass.OneVsOneClassifier.fit` and
:meth:`multiclass.OutputCodeClassifier.fit` now accept ``**fit_params`` which
are passed to the estimator used.
:meth:`multiclass.OneVsRestClassifier.partial_fit` and
:meth:`multiclass.OneVsOneClassifier.partial_fit` now accept
``**partial_fit_params`` which are passed to the estimator used. :pr:`27308`
by :user:`Stefanie Senger <StefanieSenger>`.
Member

All such entries are now placed together at the top of the file; you can move this one up there as well. You might need to merge with main to see that.

Contributor Author

@adrinjalali Do you mean #27421, which is not merged yet?

Member

No I mean #27386 which is merged.

Contributor Author

Found it, thanks. Is adding metadata routing to estimators |Enhancement| or |Feature|, by the way? Right now both are used, and I've ordered the entries alphabetically.

Member

We're gonna use |Feature| for the rest of them.

Comment on lines 363 to 366
try:
method(X, y, **method_kwargs)
except ValueError:
method(X, y, classes=np.unique(y), **method_kwargs)
Member

You could reverse the order here and catch TypeError: try passing classes first, and fall back to the call without it if the signature doesn't accept it.

Contributor Author

@adrinjalali Thanks for the suggestion. Could you please explain the advantages of reversing the order? I don't see them, unfortunately.

My thought was to try what covers the most cases first, to potentially save computational resources. The vast majority of meta-estimators don't have classes implemented.

Member

We're basically checking whether the signature of the function supports classes or not; if yes, we pass it. If the signature doesn't support it, we get a TypeError, which is standard Python behaviour. Relying on the ValueError is something in our code instead of standard Python machinery.

Contributor Author

Relying on the ValueError is something in our code instead of standard python machinery.

I'm sorry, but I can't follow your reasoning, and I'd like to learn and understand what you're asking me to change. Why is turning it around to check the least common case first better? Why is a TypeError more desirable than a ValueError?

Member

One is raised by the Python machinery, the other is raised by our code, and other estimators might raise a different error than ValueError. Relying on Python's machinery here is much more resilient and reliable than relying on code that we write in our estimators.
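
A minimal illustration of that distinction, using two hypothetical stand-in functions (not scikit-learn code):

def partial_fit_with_classes(X, y, classes=None, sample_weight=None):
    pass

def partial_fit_without_classes(X, y, sample_weight=None):
    pass

# Passing an unexpected keyword is rejected by Python itself with a TypeError,
# regardless of what the estimator's own code does.
try:
    partial_fit_without_classes([[0.0]], [0], classes=[0, 1])
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'classes'

# A ValueError such as "classes must be passed on the first call to partial_fit"
# has to be raised explicitly by estimator code, and a different estimator could
# raise something else entirely.
partial_fit_with_classes([[0.0]], [0], classes=[0, 1])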

Contributor Author

This now makes sense to me. 💡 Thank you. I will change it with the next push later today.

@adrinjalali
Member

@StefanieSenger please let me know when you want a new review, and summarize remaining points since it's hard to find the ongoing discussions in the comments at this point.

@StefanieSenger
Contributor Author
StefanieSenger commented Sep 22, 2023

@adrinjalali It's ready for a second review. The implementation and the testing are both complete as far as I am aware, and the CI is now green.

I believe that the tests in test_multiclass can now be entirely deleted. But we were waiting for @glemaitre's approval, because he had advised me to add them (before I added the multiclass classifiers to the central test file for routing).

Apart from this, there are no open questions from my side, and in case you (or someone else) wonders about something, I will try to explain my reasoning so we can find out whether it holds.

Member
@adrinjalali adrinjalali left a comment

I think the individual tests can now be removed since we're testing them in the common tests.

Comment on lines 39 to 48
- |Enhancement| :class:`~compose.ColumnTransformer` now supports metadata routing
according to :ref:`metadata routing user guide <metadata_routing>`. :pr:`27005`
by `Adrin Jalali`_.

- |Enhancement| :class:`linear_model.LogisticRegressionCV` now supports
metadata routing. :meth:`linear_model.LogisticRegressionCV.fit` now
accepts ``**params`` which are passed to the underlying splitter and
scorer. :meth:`linear_model.LogisticRegressionCV.score` now accepts
``**score_params`` which are passed to the underlying scorer.
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.
Member

These two entries are duplicates of the ones below; not sure why they're repeated here.

Contributor Author

Hm, I believe it got in during an automated merge of main into my branch. Deleting it here.

Comment on lines 38 to 43
- y = rng.randint(0, 2, size=N)
+ y = rng.randint(0, 3, size=N)
Member

I would have two ys here, y and y_multi, since I'm not sure if all estimators support multiclass. We should pass y_multi only to the ones that support it.

Member

All estimators support multiclass: https://scikit-learn.org/stable/modules/multiclass.html#multiclass-classification

This said, we also have a binary_only tag in the estimator.

Contributor Author

@adrinjalali @glemaitre Please, could you tell me whether I should work on passing two different y's?

I have tried to find out about estimator tags, and I couldn't find binary_only implemented anywhere except in the tests.

Member

In theory, you would have something like:

from sklearn.utils._tags import _safe_tags

tags = _safe_tags(estimator)
if tags["binary_only"]:
    estimator.fit(X, y_binary)
else:
    estimator.fit(X, y_multiclass)

where y_binary and y_multiclass would be the targets created at the beginning of the file.

So I foresee 2 solutions.

Create a ConsumerBinaryClassifier such that:

from sklearn.utils.multiclass import type_of_target

class ConsumerBinaryClassifier(ConsumerClassifier):
    def fit(self, X, y, **fit_params):
        # Refuse anything that is not a binary target.
        y_type = type_of_target(y)
        if y_type != "binary":
            raise ValueError("This classifier supports only binary targets")
        return super().fit(X, y, **fit_params)

    def _more_tags(self):
        return {"binary_only": True}

Otherwise, we go from the principle that there are no binary-only classifiers and we just add a # TODO comment to modify the code if one day that happens.

@adrinjalali what do you think is best?

Comment on lines 367 to 370
except TypeError:
method(X, y, **method_kwargs)
except IndexError:
method(X, y, **method_kwargs)
Member

This doesn't make sense: you're calling the method twice with the same args, so why wouldn't it raise again?

Contributor Author
@StefanieSenger StefanieSenger Sep 26, 2023

It was, because MultiOutputClassifier.partial_fit() expects classes in another format. It didn't raise a TypeError, but an IndexError instead.
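
For context, MultiOutputClassifier.partial_fit takes classes as a list of arrays, one per output, which is presumably why a single flat array from np.unique(y) trips it up. A small sketch (not code from this PR):

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.RandomState(0)
X = rng.normal(size=(20, 4))
Y = np.column_stack([rng.randint(0, 3, size=20), rng.randint(0, 2, size=20)])

clf = MultiOutputClassifier(SGDClassifier(random_state=0))

# One array of classes per output column, not a single flat array.
classes = [np.unique(Y[:, 0]), np.unique(Y[:, 1])]
clf.partial_fit(X, Y, classes=classes)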

@glemaitre glemaitre self-requested a review September 25, 2023 09:23
Member
@glemaitre glemaitre left a comment

A couple of comments but it looks good already. Nice work @StefanieSenger @adrinjalali

Comment on lines 213 to 216
if classes is not None:
self.classes_ = classes
else:
self.classes_ = np.unique(y)
Member

This does not look like the pattern that we are using in classifiers. Usually, we check whether classes_ already exists (i.e. whether or not this is the first call to partial_fit); if this is the first call, then we create classes_ using classes, which should not be None.

Member

yeah but this is a test case, I don't mind having the right thing, but I don't think it matters.

Contributor Author

@adrinjalali and @glemaitre Please let me know what you prefer and then I can try.

Contributor Author

After talking with @adrinjalali, I have substituted this part with _check_partial_fit_first_call(self, classes), because it does everything required (it checks whether classes_ exists and, if not, uses classes for it). This is also what the partial_fit of MultiOutputClassifier does, which then had to be fed with initial classes so that it can pass this check.
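
For reference, _check_partial_fit_first_call is a private helper in sklearn.utils.multiclass; roughly, the pattern in a consuming test estimator could look like this sketch (a hypothetical stand-in, not the exact code from the PR):

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import _check_partial_fit_first_call

class ConsumingClassifierSketch(ClassifierMixin, BaseEstimator):
    """Records routed metadata; hypothetical stand-in for the test helper."""

    def partial_fit(self, X, y, classes=None, sample_weight=None):
        # On the first call this sets self.classes_ from `classes` (and raises
        # if `classes` is None); on later calls it checks consistency.
        _check_partial_fit_first_call(self, classes)
        self.partial_fit_sample_weight_ = sample_weight  # record routed metadata
        return self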

Member
@adrinjalali adrinjalali left a comment

LGTM. Thanks @StefanieSenger

@glemaitre I think you can give it another review now.

@glemaitre glemaitre self-requested a review October 2, 2023 13:43
Member
@glemaitre glemaitre left a comment

LGTM on my side as well.

Kudos @StefanieSenger. This was not as straightforward as usual, given the partial_fit business.

@glemaitre glemaitre enabled auto-merge (squash) October 2, 2023 13:48
@glemaitre glemaitre merged commit 3260a93 into scikit-learn:main Oct 2, 2023
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
…nd OutputCodeClassifier (scikit-learn#27308)

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@StefanieSenger StefanieSenger deleted the routing1 branch April 18, 2024 10:59