API Revamp estimator tags by adrinjalali · Pull Request #29677 · scikit-learn/scikit-learn · GitHub
API Revamp estimator tags #29677

Merged
merged 63 commits into from
Sep 4, 2024
3ebf1c3
ENH Uses __sklearn_tags__ for tags instead of mro walking
thomasjpfan Feb 24, 2022
b06fd10
DOC Adds whats new
thomasjpfan Feb 24, 2022
e68f7f0
CI Fix assign/unassign CI
thomasjpfan Jan 2, 2024
a5e9560
CI Fix assign/unassign CI
thomasjpfan Jan 2, 2024
bb277da
Merge remote-tracking branch 'origin/main' into tags_redesign
thomasjpfan Apr 29, 2024
3ae8506
Update new estimators with __sklearn_tags__
thomasjpfan Apr 29, 2024
d088cac
Merge remote-tracking branch 'upstream/main' into tags_redesign
thomasjpfan Apr 29, 2024
dc59eef
Update to fix failing tests
thomasjpfan Apr 29, 2024
b4deda2
STY Lint
thomasjpfan Apr 29, 2024
87c113b
Fixes failing tests
thomasjpfan Apr 30, 2024
797dddf
Add dataclasses
adrinjalali Aug 13, 2024
3175a82
Merge remote-tracking branch 'upstream/main' into tags_redesign
adrinjalali Aug 13, 2024
c596d08
move to 1.6
adrinjalali Aug 13, 2024
f2ae973
Merge remote-tracking branch 'upstream/main' into tags_redesign
adrinjalali Aug 13, 2024
004d5a6
simplify bagging __sklearn_tags__
adrinjalali Aug 14, 2024
4edd1da
RFE allow nan handling
adrinjalali Aug 14, 2024
813a8c4
test fixes
adrinjalali Aug 14, 2024
87e3f25
merged with tags_redesign
adrinjalali Aug 14, 2024
d073f1a
progress
adrinjalali Aug 15, 2024
bdac769
remove old tags
adrinjalali Aug 15, 2024
dcf6051
more fixes
adrinjalali Aug 15, 2024
d93155b
tune more tests
adrinjalali Aug 15, 2024
41ed204
...
adrinjalali Aug 15, 2024
10519e8
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Aug 15, 2024
177d763
a lot more estimators
adrinjalali Aug 19, 2024
8914bf0
...
adrinjalali Aug 19, 2024
d823d13
...
adrinjalali Aug 19, 2024
220f215
...
adrinjalali Aug 19, 2024
64a55a4
rename back test name
adrinjalali Aug 19, 2024
3be6c6e
...
adrinjalali Aug 19, 2024
042b16b
...
adrinjalali Aug 19, 2024
bff06b1
...
adrinjalali Aug 19, 2024
3285fe2
...
adrinjalali Aug 20, 2024
9dbd500
...
adrinjalali Aug 20, 2024
244c1dc
...
adrinjalali Aug 20, 2024
d42b4ba
...
adrinjalali Aug 20, 2024
6c47f58
...
adrinjalali Aug 20, 2024
14a29a9
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Aug 20, 2024
d9e7305
...
adrinjalali Aug 20, 2024
c164d71
self review
adrinjalali Aug 20, 2024
bc2f53c
codecov review
adrinjalali Aug 20, 2024
d8bb6a3
docs and API
adrinjalali Aug 20, 2024
b9dada7
changelog
adrinjalali Aug 20, 2024
79aefe0
Update sklearn/utils/_tags.py
adrinjalali Aug 23, 2024
43ee0fb
Update sklearn/utils/_tags.py
adrinjalali Aug 23, 2024
2283851
Update sklearn/utils/_tags.py
adrinjalali Aug 23, 2024
c448328
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Aug 27, 2024
87a1416
Merge branch 'estimator-tags' of github.com:adrinjalali/scikit-learn …
adrinjalali Aug 27, 2024
b740987
Omar's comments
adrinjalali Aug 27, 2024
789c52a
Update sklearn/multioutput.py
adrinjalali Aug 27, 2024
07b4bfd
Update sklearn/utils/_tags.py
adrinjalali Aug 30, 2024
2076697
Merge branch 'main' into estimator-tags
adrinjalali Aug 30, 2024
635ce30
fix tab/space
adrinjalali Aug 30, 2024
241559d
Add tags to docs
adrinjalali Sep 2, 2024
398dd32
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Sep 2, 2024
9c504dd
preserves_dtype is not a list of str
adrinjalali Sep 2, 2024
99486db
add missing required change in _tags.py
adrinjalali Sep 3, 2024
c0869ab
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Sep 3, 2024
0b0865a
Merge branch 'main' into estimator-tags
glemaitre Sep 4, 2024
6ac2b7c
Most Guillaume's comments
adrinjalali Sep 4, 2024
1df3729
Merge branch 'estimator-tags' of github.com:adrinjalali/scikit-learn …
adrinjalali Sep 4, 2024
55a5623
Merge remote-tracking branch 'upstream/main' into estimator-tags
adrinjalali Sep 4, 2024
88af896
remove dtype map
adrinjalali Sep 4, 2024
8 changes: 8 additions & 0 deletions doc/api_reference.py
Expand Up @@ -1159,6 +1159,14 @@ def _get_submodule(module_name, submodule_name):
"safe_mask",
"safe_sqr",
"shuffle",
"Tags",
"InputTags",
"TargetTags",
"ClassifierTags",
"RegressorTags",
"TransformerTags",
"default_tags",
"get_tags",
],
},
{
163 changes: 33 additions & 130 deletions doc/developers/develop.rst
Expand Up @@ -523,141 +523,44 @@ Estimator Tags

The estimator tags are experimental and the API is subject to change.

Scikit-learn introduced estimator tags in version 0.21. These are annotations
of estimators that allow programmatic inspection of their capabilities, such as
sparse matrix support, supported output types and supported methods. The
estimator tags are a dictionary returned by the method ``_get_tags()``. These
.. note::

Scikit-learn introduced estimator tags in version 0.21 as a
private API, used mostly in tests. However, these tags expanded
over time and many third-party developers also need to use
them. Therefore, in version 1.6 the API for the tags was revamped
and exposed as public API.

The estimator tags are annotations of estimators that allow
programmatic inspection of their capabilities, such as sparse matrix
support, supported output types and supported methods. The estimator
tags are an instance of :class:`~sklearn.utils.Tags` returned by the
method :meth:`~sklearn.base.BaseEstimator.__sklearn_tags__()`. These
tags are used in the common checks run by the
:func:`~sklearn.utils.estimator_checks.check_estimator` function and the
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` decorator.
Tags determine which checks to run and what input data is appropriate. Tags
can depend on estimator parameters or even system architecture and can in
general only be determined at runtime.

The current set of estimator tags are:

allow_nan (default=False)
whether the estimator supports data with missing values encoded as np.nan

array_api_support (default=False)
whether the estimator supports Array API compatible inputs.

binary_only (default=False)
whether estimator supports binary classification but lacks multi-class
classification support.

multilabel (default=False)
whether the estimator supports multilabel output

multioutput (default=False)
whether a regressor supports multi-target outputs or a classifier supports
multi-class multi-output.

multioutput_only (default=False)
whether estimator supports only multi-output classification or regression.

no_validation (default=False)
whether the estimator skips input-validation. This is only meant for
stateless and dummy transformers!

non_deterministic (default=False)
whether the estimator is not deterministic given a fixed ``random_state``

pairwise (default=False)
This boolean attribute indicates whether the data (`X`) :term:`fit` and
similar methods consists of pairwise measures over samples rather than a
feature representation for each sample. It is usually `True` where an
estimator has a `metric` or `affinity` or `kernel` parameter with value
'precomputed'. Its primary purpose is to support a :term:`meta-estimator`
or a cross validation procedure that extracts a sub-sample of data intended
for a pairwise estimator, where the data needs to be indexed on both axes.
Specifically, this tag is used by
`sklearn.utils.metaestimators._safe_split` to slice rows and
columns.

preserves_dtype (default=``[np.float64]``)
applies only on transformers. It corresponds to the data types which will
be preserved such that `X_trans.dtype` is the same as `X.dtype` after
calling `transformer.transform(X)`. If this list is empty, then the
transformer is not expected to preserve the data type. The first value in
the list is considered as the default data type, corresponding to the data
type of the output when the input data type is not going to be preserved.

poor_score (default=False)
whether the estimator fails to provide a "reasonable" test-set score, which
currently for regression is an R2 of 0.5 on ``make_regression(n_samples=200,
n_features=10, n_informative=1, bias=5.0, noise=20, random_state=42)``, and
for classification an accuracy of 0.83 on
``make_blobs(n_samples=300, random_state=0)``. These datasets and values
are based on current estimators in sklearn and might be replaced by
something more systematic.

requires_fit (default=True)
whether the estimator requires to be fitted before calling one of
`transform`, `predict`, `predict_proba`, or `decision_function`.

requires_positive_X (default=False)
whether the estimator requires positive X.

requires_y (default=False)
whether the estimator requires y to be passed to `fit`, `fit_predict` or
`fit_transform` methods. The tag is True for estimators inheriting from
`~sklearn.base.RegressorMixin` and `~sklearn.base.ClassifierMixin`.

requires_positive_y (default=False)
whether the estimator requires a positive y (only applicable for regression).

_skip_test (default=False)
whether to skip common tests entirely. Don't use this unless you have a
*very good* reason.

_xfail_checks (default=False)
dictionary ``{check_name: reason}`` of common checks that will be marked
as `XFAIL` for pytest, when using
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These
checks will be simply ignored and not run by
:func:`~sklearn.utils.estimator_checks.check_estimator`, but a
`SkipTestWarning` will be raised.
Don't use this unless there is a *very good* reason for your estimator
not to pass the check.
Also note that the usage of this tag is highly subject to change because
we are trying to make it more flexible: be prepared for breaking changes
in the future.

stateless (default=False)
whether the estimator needs access to data for fitting. Even though an
estimator is stateless, it might still need a call to ``fit`` for
initialization.

X_types (default=['2darray'])
Supported input types for X as list of strings. Tests are currently only
run if '2darray' is contained in the list, signifying that the estimator
takes continuous 2d numpy arrays as input. The default value is
['2darray']. Other possible types are ``'string'``, ``'sparse'``,
``'categorical'``, ``dict``, ``'1dlabels'`` and ``'2dlabels'``. The goal is
that in the future the supported input type will determine the data used
during testing, in particular for ``'string'``, ``'sparse'`` and
``'categorical'`` data. For now, the test for sparse data do not make use
of the ``'sparse'`` tag.

It is unlikely that the default values for each tag will suit the needs of your
specific estimator. Additional tags can be created or default tags can be
overridden by defining a `_more_tags()` method which returns a dict with the
desired overridden tags or new tags. For example::
:func:`~sklearn.utils.estimator_checks.check_estimator` function and
the :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`
decorator. Tags determine which checks to run and what input data is
appropriate. Tags can depend on estimator parameters or even system
architecture and can in general only be determined at runtime and
are therefore instance attributes rather than class attributes. See
:class:`~sklearn.utils.Tags` for more information about individual
tags.
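
The point about tags being determined per instance at runtime can be illustrated with a minimal sketch. It assumes reduced stand-in dataclasses: `InputTags`, `Tags`, `BaseSketch` and `PrecomputedAwareEstimator` below are hypothetical names for illustration only and are not the real `sklearn.utils` classes, which carry more fields.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    # Reduced stand-in: the real class carries more fields.
    pairwise: bool = False


@dataclass
class Tags:
    # Reduced stand-in for the top-level tags container.
    input_tags: InputTags = field(default_factory=InputTags)


class BaseSketch:
    def __sklearn_tags__(self):
        # Build a fresh object on every call so instances never share state.
        return Tags()


class PrecomputedAwareEstimator(BaseSketch):
    def __init__(self, affinity="euclidean"):
        self.affinity = affinity

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # The tag value is only known once the parameter has been set.
        tags.input_tags.pairwise = self.affinity == "precomputed"
        return tags


default_est = PrecomputedAwareEstimator()
pairwise_est = PrecomputedAwareEstimator(affinity="precomputed")
print(default_est.__sklearn_tags__().input_tags.pairwise)
print(pairwise_est.__sklearn_tags__().input_tags.pairwise)
```

Two instances of the same class report different values for the same tag, which is why a class-level attribute would not be enough.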

It is unlikely that the default values for each tag will suit the
needs of your specific estimator. You can change the default values by
defining a `__sklearn_tags__()` method which returns the new values
for your estimator's tags. For example::

class MyMultiOutputEstimator(BaseEstimator):

def _more_tags(self):
return {'multioutput_only': True,
'non_deterministic': True}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.single_output = False
tags.non_deterministic = True
return tags

Any tag that is not in `_more_tags()` will just fall-back to the default values
documented above.

Even if it is not recommended, it is possible to override the method
`_get_tags()`. Note however that **all tags must be present in the dict**. If
any of the keys documented above is not present in the output of `_get_tags()`,
an error will occur.
You can create a new subclass of :class:`~sklearn.utils.Tags` if you wish
to add new tags to the existing set.
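
One way the subclassing idea might look is sketched here. This is an assumption-laden illustration, not the library's prescribed pattern: `Tags` is a one-field stand-in rather than the real `sklearn.utils.Tags`, and `MyTags`, `supports_streaming` and `StreamingEstimator` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Tags:
    # Reduced stand-in; only one of the real fields is reproduced here.
    non_deterministic: bool = False


@dataclass
class MyTags(Tags):
    # Hypothetical project-specific tag added on top of the inherited ones.
    supports_streaming: bool = False


class StreamingEstimator:
    def __sklearn_tags__(self):
        # Return the extended container; inherited fields keep their defaults.
        return MyTags(supports_streaming=True)


tags = StreamingEstimator().__sklearn_tags__()
print(isinstance(tags, Tags), tags.supports_streaming, tags.non_deterministic)
```

Because `MyTags` is still a `Tags`, code that only knows about the base fields keeps working, while project code can read the extra attribute.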

In addition to the tags, estimators also need to declare any non-optional
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
3 changes: 1 addition & 2 deletions doc/glossary.rst
Expand Up @@ -407,8 +407,7 @@ General Concepts
likelihoods.

estimator tags
A proposed feature (e.g. :issue:`8022`) by which the capabilities of an
estimator are described through a set of semantic tags. This would
Estimator tags describe certain capabilities of an estimator. This would
enable some runtime behaviors based on estimator inspection, but it
also allows each estimator to be tested for appropriate invariances
while being excepted from other :term:`common tests`.
2 changes: 1 addition & 1 deletion doc/sphinxext/allow_nan_estimators.py
Expand Up @@ -21,7 +21,7 @@ def make_paragraph_for_estimator_type(estimator_type):
with suppress(SkipTest):
est = _construct_instance(est_class)

if est._get_tags().get("allow_nan"):
if est.__sklearn_tags__().input_tags.allow_nan:
module_name = ".".join(est_class.__module__.split(".")[:2])
class_title = f"{est_class.__name__}"
class_url = f"./generated/{module_name}.{class_title}.html"
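
The call-site change in this hunk follows a pattern that recurs throughout the diff: a flat dict key becomes a nested attribute. The sketch below collects only the key-to-attribute pairs that appear in the hunks on this page; `OLD_TO_NEW` and `read_tag` are hypothetical helpers, and the `tags` object is a `SimpleNamespace` stand-in, not a real `Tags` instance.

```python
from types import SimpleNamespace

# Old dict key -> new attribute path, limited to pairs visible in this diff.
OLD_TO_NEW = {
    "allow_nan": "input_tags.allow_nan",
    "pairwise": "input_tags.pairwise",
    "requires_y": "target_tags.required",
    "multioutput": "target_tags.multi_output",
    "non_deterministic": "non_deterministic",
}


def read_tag(tags, old_key):
    """Resolve an old-style key against a new-style tags object."""
    obj = tags
    for name in OLD_TO_NEW[old_key].split("."):
        obj = getattr(obj, name)
    return obj


# Stand-in object shaped like the new nested tags (not a real sklearn Tags).
tags = SimpleNamespace(
    non_deterministic=False,
    input_tags=SimpleNamespace(allow_nan=True, pairwise=False),
    target_tags=SimpleNamespace(required=True, multi_output=False),
)
print(read_tag(tags, "allow_nan"), read_tag(tags, "requires_y"))
```

Not every old key maps one-to-one: in the documentation example above, `multioutput_only: True` is expressed as `target_tags.single_output = False`, so an inverted tag like that needs explicit handling rather than a lookup table.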
8 changes: 8 additions & 0 deletions doc/whats_new/v1.6.rst
Expand Up @@ -22,6 +22,14 @@ Version 1.6.0

**In Development**

Changes impacting many modules
------------------------------

- |Enhancement| `__sklearn_tags__` was introduced for setting tags in estimators.
More details in :ref:`estimator_tags`.
:pr:`22606` by `Thomas Fan`_ and :pr:`29677` by `Adrin Jalali`_.


Support for Array API
---------------------

51 changes: 19 additions & 32 deletions sklearn/base.py
Expand Up @@ -20,9 +20,7 @@
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
from .utils._param_validation import validate_parameter_constraints
from .utils._set_output import _SetOutputMixin
from .utils._tags import (
_DEFAULT_TAGS,
)
from .utils._tags import default_tags
from .utils.fixes import _IS_32BIT
from .utils.validation import (
_check_feature_names_in,
Expand Up @@ -385,19 +383,8 @@ def __setstate__(self, state):
except AttributeError:
self.__dict__.update(state)

def _more_tags(self):
return _DEFAULT_TAGS

def _get_tags(self):
collected_tags = {}
for base_class in reversed(inspect.getmro(self.__class__)):
if hasattr(base_class, "_more_tags"):
# need the if because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = base_class._more_tags(self)
collected_tags.update(more_tags)
return collected_tags
def __sklearn_tags__(self):
return default_tags(self)

def _check_n_features(self, X, reset):
"""Set the `n_features_in_` attribute, or check against it.
Expand Up @@ -607,7 +594,7 @@ def _validate_data(
"""
self._check_feature_names(X, reset=reset)

if y is None and self._get_tags()["requires_y"]:
if y is None and self.__sklearn_tags__().target_tags.required:
raise ValueError(
f"This {self.__class__.__name__} estimator "
"requires y to be passed, but the target y is None."
Expand Up @@ -763,9 +750,6 @@ def score(self, X, y, sample_weight=None):

return accuracy_score(y, self.predict(X), sample_weight=sample_weight)

def _more_tags(self):
return {"requires_y": True}


class RegressorMixin:
"""Mixin class for all regression estimators in scikit-learn.
Expand Up @@ -848,9 +832,6 @@ def score(self, X, y, sample_weight=None):
y_pred = self.predict(X)
return r2_score(y, y_pred, sample_weight=sample_weight)

def _more_tags(self):
return {"requires_y": True}


class ClusterMixin:
"""Mixin class for all cluster estimators in scikit-learn.
Expand Up @@ -900,8 +881,11 @@ def fit_predict(self, X, y=None, **kwargs):
self.fit(X, **kwargs)
return self.labels_

def _more_tags(self):
return {"preserves_dtype": []}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
if tags.transformer_tags is not None:
tags.transformer_tags.preserves_dtype = []
return tags


class BiclusterMixin:
Expand Up @@ -1344,18 +1328,21 @@ class MetaEstimatorMixin:
class MultiOutputMixin:
"""Mixin to mark estimators that support multioutput."""

def _more_tags(self):
return {"multioutput": True}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.multi_output = True
return tags


class _UnstableArchMixin:
"""Mark estimators that are non-determinstic on 32bit or PowerPC"""

def _more_tags(self):
return {
"non_deterministic": _IS_32BIT
or platform.machine().startswith(("ppc", "powerpc"))
}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.non_deterministic = _IS_32BIT or platform.machine().startswith(
("ppc", "powerpc")
)
return tags


def is_classifier(estimator):
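
The hunks above replace an explicit MRO walk (`_get_tags` merging `_more_tags` dicts) with cooperative `super().__sklearn_tags__()` calls. The sketch below contrasts the two schemes on toy classes. All class names are hypothetical, and the new scheme uses a plain dict instead of a `Tags` object purely to keep the example short.

```python
import inspect


class OldBase:
    def _more_tags(self):
        return {"requires_y": False, "multioutput": False}

    def _get_tags(self):
        # Old scheme: walk the MRO from the root down and merge the dicts.
        collected = {}
        for klass in reversed(inspect.getmro(type(self))):
            if hasattr(klass, "_more_tags"):
                collected.update(klass._more_tags(self))
        return collected


class OldMultiOutputMixin:
    def _more_tags(self):
        return {"multioutput": True}


class OldEstimator(OldMultiOutputMixin, OldBase):
    pass


class NewBase:
    def __sklearn_tags__(self):
        # New scheme: the root creates the defaults (a dict here for brevity).
        return {"requires_y": False, "multioutput": False}


class NewMultiOutputMixin:
    def __sklearn_tags__(self):
        # Each class defers to the next one in the MRO, then adjusts the result.
        tags = super().__sklearn_tags__()
        tags["multioutput"] = True
        return tags


class NewEstimator(NewMultiOutputMixin, NewBase):
    pass


old_tags = OldEstimator()._get_tags()
new_tags = NewEstimator().__sklearn_tags__()
print(old_tags == new_tags)
```

In the cooperative version the mixin has to come before the base class in the inheritance list; otherwise `super().__sklearn_tags__()` in the mixin has no implementation to defer to.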
18 changes: 9 additions & 9 deletions sklearn/calibration.py
Expand Up @@ -540,16 +540,16 @@ def get_metadata_routing(self):
)
return router

def _more_tags(self):
return {
"_xfail_checks": {
"check_sample_weights_invariance": (
"Due to the cross-validation and sample ordering, removing a sample"
" is not strictly equal to putting is weight to zero. Specific unit"
" tests are added for CalibratedClassifierCV specifically."
),
}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_sample_weights_invariance": (
"Due to the cross-validation and sample ordering, removing a sample"
" is not strictly equal to putting is weight to zero. Specific unit"
" tests are added for CalibratedClassifierCV specifically."
),
}
return tags


def _fit_classifier_calibrator_pair(
6 changes: 4 additions & 2 deletions sklearn/cluster/_affinity_propagation.py
Expand Up @@ -478,8 +478,10 @@ def __init__(
self.affinity = affinity
self.random_state = random_state

def _more_tags(self):
return {"pairwise": self.affinity == "precomputed"}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.pairwise = self.affinity == "precomputed"
return tags

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
24 changes: 12 additions & 12 deletions sklearn/cluster/_bicluster.py
Expand Up @@ -193,19 +193,19 @@ def _k_means(self, data, n_clusters):
labels = model.labels_
return centroid, labels

def _more_tags(self):
return {
"_xfail_checks": {
"check_estimators_dtypes": "raises nan error",
"check_fit2d_1sample": "_scale_normalize fails",
"check_fit2d_1feature": "raises apply_along_axis error",
"check_estimator_sparse_matrix": "does not fail gracefully",
"check_estimator_sparse_array": "does not fail gracefully",
"check_methods_subset_invariance": "empty array passed inside",
"check_dont_overwrite_parameters": "empty array passed inside",
"check_fit2d_predict1d": "empty array passed inside",
}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_estimators_dtypes": "raises nan error",
"check_fit2d_1sample": "_scale_normalize fails",
"check_fit2d_1feature": "raises apply_along_axis error",
"check_estimator_sparse_matrix": "does not fail gracefully",
"check_estimator_sparse_array": "does not fail gracefully",
"check_methods_subset_invariance": "empty array passed inside",
"check_dont_overwrite_parameters": "empty array passed inside",
"check_fit2d_predict1d": "empty array passed inside",
}
return tags


class SpectralCoclustering(BaseSpectral):