[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH make fit_transform and fit_predict composite methods (SLEP6) #26506

Merged
merged 15 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.3, including:

TODO: update at the time of the release.

:mod:`sklearn.base`
...................

- |Enhancement| :func:`base.ClusterMixin.fit_predict` and
:func:`base.OutlierMixin.fit_predict` now accept ``**kwargs`` which are
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
passed to the ``fit`` method of the estimator. :pr:`26506` by `Adrin
Jalali`_.
18 changes: 14 additions & 4 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ class ClusterMixin:

_estimator_type = "clusterer"

def fit_predict(self, X, y=None):
def fit_predict(self, X, y=None, **kwargs):
"""
Perform clustering on `X` and returns cluster labels.

Expand All @@ -777,14 +777,19 @@ def fit_predict(self, X, y=None):
y : Ignored
Not used, present for API consistency by convention.

**kwargs : dict
Arguments to be passed to ``fit``.

.. versionadded:: 1.4

Returns
-------
labels : ndarray of shape (n_samples,), dtype=np.int64
Cluster labels.
"""
# non-optimized default implementation; override when a better
# method is possible for a given clustering algorithm
self.fit(X)
self.fit(X, **kwargs)
return self.labels_

def _more_tags(self):
Expand Down Expand Up @@ -1010,7 +1015,7 @@ class OutlierMixin:

_estimator_type = "outlier_detector"

def fit_predict(self, X, y=None):
def fit_predict(self, X, y=None, **kwargs):
"""Perform fit on X and returns labels for X.

Returns -1 for outliers and 1 for inliers.
Expand All @@ -1023,13 +1028,18 @@ def fit_predict(self, X, y=None):
y : Ignored
Not used, present for API consistency by convention.

**kwargs : dict
Arguments to be passed to ``fit``.

.. versionadded:: 1.4

Returns
-------
y : ndarray of shape (n_samples,)
1 for inliers, -1 for outliers.
"""
# override for transductive outlier detectors like LocalOutlierFactor
return self.fit(X).predict(X)
return self.fit(X, **kwargs).predict(X)


class MetaEstimatorMixin:
Expand Down
81 changes: 66 additions & 15 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from sklearn.utils.metadata_routing import MetadataRouter
from sklearn.utils.metadata_routing import MethodMapping
from sklearn.utils.metadata_routing import process_routing
from sklearn.utils._metadata_requests import MethodPair
from sklearn.utils._metadata_requests import MethodMetadataRequest
from sklearn.utils._metadata_requests import _MetadataRequester
from sklearn.utils._metadata_requests import METHODS
from sklearn.utils._metadata_requests import METHODS, SIMPLE_METHODS, COMPOSITE_METHODS
from sklearn.utils._metadata_requests import request_is_alias
from sklearn.utils._metadata_requests import request_is_valid

Expand Down Expand Up @@ -58,7 +59,7 @@ def assert_request_is_empty(metadata_request, exclude=None):
return

exclude = [] if exclude is None else exclude
for method in METHODS:
for method in SIMPLE_METHODS:
if method in exclude:
continue
mmr = getattr(metadata_request, method)
Expand All @@ -75,7 +76,7 @@ def assert_request_equal(request, dictionary):
mmr = getattr(request, method)
assert mmr.requests == requests

empty_methods = [method for method in METHODS if method not in dictionary]
empty_methods = [method for method in SIMPLE_METHODS if method not in dictionary]
for method in empty_methods:
assert not len(getattr(request, method).requests)

Expand Down Expand Up @@ -819,17 +820,9 @@ def test_methodmapping():
assert mm_list[1] == ("fit", "fit")

mm = MethodMapping.from_str("one-to-one")
assert (
str(mm)
== "[{'callee': 'fit', 'caller': 'fit'}, {'callee': 'partial_fit', 'caller':"
" 'partial_fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee':"
" 'predict_proba', 'caller': 'predict_proba'}, {'callee':"
" 'predict_log_proba', 'caller': 'predict_log_proba'}, {'callee':"
" 'decision_function', 'caller': 'decision_function'}, {'callee': 'score',"
" 'caller': 'score'}, {'callee': 'split', 'caller': 'split'}, {'callee':"
" 'transform', 'caller': 'transform'}, {'callee': 'inverse_transform',"
" 'caller': 'inverse_transform'}]"
)
for method in METHODS:
assert MethodPair(method, method) in mm._routes
assert len(mm._routes) == len(METHODS)

mm = MethodMapping.from_str("score")
assert repr(mm) == "[{'callee': 'score', 'caller': 'score'}]"
Expand Down Expand Up @@ -944,6 +937,12 @@ class SimpleEstimator(BaseEstimator):
def fit(self, X, y):
pass # pragma: no cover

def fit_transform(self, X, y):
pass # pragma: no cover

def fit_predict(self, X, y):
pass # pragma: no cover

def partial_fit(self, X, y):
pass # pragma: no cover

Expand Down Expand Up @@ -979,6 +978,12 @@ class SimpleEstimator(BaseEstimator):
def fit(self, X, y, sample_weight=None):
pass # pragma: no cover

def fit_transform(self, X, y, sample_weight=None):
pass # pragma: no cover

def fit_predict(self, X, y, sample_weight=None):
pass # pragma: no cover

def partial_fit(self, X, y, sample_weight=None):
pass # pragma: no cover

Expand Down Expand Up @@ -1006,10 +1011,56 @@ def transform(self, X, sample_weight=None):
def inverse_transform(self, X, sample_weight=None):
pass # pragma: no cover

for method in METHODS:
# composite methods shouldn't have a corresponding set method.
for method in COMPOSITE_METHODS:
assert not hasattr(SimpleEstimator(), f"set_{method}_request")

# simple methods should have a corresponding set method.
for method in SIMPLE_METHODS:
assert hasattr(SimpleEstimator(), f"set_{method}_request")


def test_composite_methods():
    """Check that composite methods merge the requests of their components.

    ``fit_transform`` and ``fit_predict`` have no ``set_*_request`` of their
    own; their requests are composed from ``fit`` + ``transform`` and
    ``fit`` + ``predict`` respectively, and conflicting values must raise.
    """

    class SimpleEstimator(BaseEstimator):
        # Only simple methods are defined here; the composite requests are
        # derived from them.
        def fit(self, X, y, foo=None, bar=None):
            pass  # pragma: no cover

        def predict(self, X, foo=None, bar=None):
            pass  # pragma: no cover

        def transform(self, X, other_param=None):
            pass  # pragma: no cover

    est = SimpleEstimator()
    # Nothing has been requested explicitly yet, so every metadata accepted by
    # the component methods is present with the default (unset) request value.
    assert est.get_metadata_routing().fit_transform.requests == {
        "bar": None,
        "foo": None,
        "other_param": None,
    }
    assert est.get_metadata_routing().fit_predict.requests == {"bar": None, "foo": None}

    # setting the request on only one of them should raise an error
    est.set_fit_request(foo=True, bar="test")
    with pytest.raises(ValueError, match="Conflicting metadata requests for"):
        est.get_metadata_routing().fit_predict

    # setting the request on the other one should fail if not the same as the
    # first method
    est.set_predict_request(bar=True)
    with pytest.raises(ValueError, match="Conflicting metadata requests for"):
        est.get_metadata_routing().fit_predict

    # now the requests are consistent, so composing them must succeed
    est.set_predict_request(foo=True, bar="test")
    assert est.get_metadata_routing().fit_predict.requests == {
        "bar": "test",
        "foo": True,
    }

    # setting the request for a non-overlapping parameter would merge them
    # together.
    est.set_transform_request(other_param=True)
    assert est.get_metadata_routing().fit_transform.requests == {
        "bar": "test",
        "foo": True,
        "other_param": True,
    }


def test_no_feature_flag_raises_error():
"""Test that when feature flag disabled, set_{method}_requests raises."""
with config_context(enable_metadata_routing=False):
Expand Down
58 changes: 51 additions & 7 deletions sklearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@

# Only the following methods are supported in the routing mechanism. Adding new
# methods at the moment involves monkeypatching this list.
METHODS = [
SIMPLE_METHODS = [
"fit",
"partial_fit",
"predict",
Expand All @@ -102,6 +102,16 @@
"inverse_transform",
]

# Composite methods chain several simple methods together. Their requests
# cannot be set directly; they are obtained from the requests of the simple
# methods listed here, which are what users should configure instead.
COMPOSITE_METHODS = dict(
    fit_transform=["fit", "transform"],
    fit_predict=["fit", "predict"],
)

# All method names known to the routing mechanism: the simple ones first,
# followed by the composite ones.
METHODS = [*SIMPLE_METHODS, *COMPOSITE_METHODS]


def _routing_enabled():
"""Return whether metadata routing is enabled.
Expand Down Expand Up @@ -195,10 +205,13 @@ class MethodMetadataRequest:

method : str
The name of the method to which these requests belong.

requests : dict of {str: bool, None or str}, default=None
The initial requests for this method.
"""

def __init__(self, owner, method):
self._requests = dict()
def __init__(self, owner, method, requests=None):
self._requests = requests or dict()
self.owner = owner
self.method = method

Expand Down Expand Up @@ -383,13 +396,44 @@ class MetadataRequest:
_type = "metadata_request"

def __init__(self, owner):
for method in METHODS:
self.owner = owner
for method in SIMPLE_METHODS:
setattr(
self,
method,
MethodMetadataRequest(owner=owner, method=method),
)

def __getattr__(self, name):
    """Compose the request object of a composite method on demand.

    Python only falls back to this hook once normal attribute lookup has
    failed, so the simple methods stored as instance attributes never
    reach it. See
    https://docs.python.org/3/reference/datamodel.html#object.__getattr__

    Parameters
    ----------
    name : str
        Name of the attribute being looked up.

    Returns
    -------
    MethodMetadataRequest
        A new request object for the composite method ``name`` holding
        the merged requests of its component methods.

    Raises
    ------
    AttributeError
        If ``name`` is not a key of ``COMPOSITE_METHODS``.
    ValueError
        If two component methods request the same metadata with
        different values.
    """
    if name not in COMPOSITE_METHODS:
        cls_name = self.__class__.__name__
        raise AttributeError(f"'{cls_name}' object has no attribute '{name}'")

    components = COMPOSITE_METHODS[name]
    merged = {}
    for component in components:
        mmr = getattr(self, component)
        # Only keys requested by both an earlier component and this one can
        # clash; anything else is merged as-is.
        shared = set(merged.keys()) & set(mmr.requests.keys())
        clashing = [key for key in shared if merged[key] != mmr._requests[key]]
        if clashing:
            raise ValueError(
                f"Conflicting metadata requests for {', '.join(clashing)} while"
                f" composing the requests for {name}. Metadata with the same name"
                f" for methods {', '.join(components)} should have the"
                " same request value."
            )
        merged.update(mmr._requests)
    return MethodMetadataRequest(owner=self.owner, method=name, requests=merged)

def _get_param_names(self, method, return_alias, ignore_self_request=None):
"""Get names of all metadata that can be consumed or routed by specified \
method.
Expand Down Expand Up @@ -463,7 +507,7 @@ def _serialize(self):
A serialized version of the instance in the form of a dictionary.
"""
output = dict()
for method in METHODS:
for method in SIMPLE_METHODS:
mmr = getattr(self, method)
if len(mmr.requests):
output[method] = mmr._serialize()
Expand Down Expand Up @@ -1121,7 +1165,7 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
return

for method in METHODS:
for method in SIMPLE_METHODS:
mmr = getattr(requests, method)
# set ``set_{method}_request``` methods
if not len(mmr.requests):
Expand Down Expand Up @@ -1181,7 +1225,7 @@ class attributes, as well as determining request keys from method
"""
requests = MetadataRequest(owner=cls.__name__)

for method in METHODS:
for method in SIMPLE_METHODS:
setattr(
requests,
method,
Expand Down