API: Estimator tags (#8022) · scikit-learn/scikit-learn@ab2f539

Commit ab2f539

amueller authored and glemaitre committed
API: Estimator tags (#8022)
1 parent 79b549c commit ab2f539

49 files changed

+673
-301
lines changed

doc/developers/contributing.rst

Lines changed: 88 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1419,22 +1419,18 @@ advised to maintain notes on the `GitHub wiki
14191419
Specific models
14201420
---------------
14211421

1422-
Classifiers should accept ``y`` (target) arguments to ``fit``
1423-
that are sequences (lists, arrays) of either strings or integers.
1424-
They should not assume that the class labels
1425-
are a contiguous range of integers;
1426-
instead, they should store a list of classes
1427-
in a ``classes_`` attribute or property.
1428-
The order of class labels in this attribute
1429-
should match the order in which ``predict_proba``, ``predict_log_proba``
1430-
and ``decision_function`` return their values.
1431-
The easiest way to achieve this is to put::
1422+
Classifiers should accept ``y`` (target) arguments to ``fit`` that are
1423+
sequences (lists, arrays) of either strings or integers. They should not
1424+
assume that the class labels are a contiguous range of integers; instead, they
1425+
should store a list of classes in a ``classes_`` attribute or property. The
1426+
order of class labels in this attribute should match the order in which
1427+
``predict_proba``, ``predict_log_proba`` and ``decision_function`` return their
1428+
values. The easiest way to achieve this is to put::
14321429

14331430
self.classes_, y = np.unique(y, return_inverse=True)
14341431

1435-
in ``fit``.
1436-
This returns a new ``y`` that contains class indexes, rather than labels,
1437-
in the range [0, ``n_classes``).
1432+
in ``fit``. This returns a new ``y`` that contains class indexes, rather than
1433+
labels, in the range [0, ``n_classes``).
14381434

14391435
A classifier's ``predict`` method should return
14401436
arrays containing class labels from ``classes_``.
@@ -1445,14 +1441,89 @@ this can be achieved with::
14451441
D = self.decision_function(X)
14461442
return self.classes_[np.argmax(D, axis=1)]
14471443

1448-
In linear models, coefficients are stored in an array called ``coef_``,
1449-
and the independent term is stored in ``intercept_``.
1450-
``sklearn.linear_model.base`` contains a few base classes and mixins
1451-
that implement common linear model patterns.
1444+
In linear models, coefficients are stored in an array called ``coef_``, and the
1445+
independent term is stored in ``intercept_``. ``sklearn.linear_model.base``
1446+
contains a few base classes and mixins that implement common linear model
1447+
patterns.
14521448

14531449
The :mod:`sklearn.utils.multiclass` module contains useful functions
14541450
for working with multiclass and multilabel problems.
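
The label-encoding and prediction conventions described above can be sketched without NumPy. This is only an illustration of what ``np.unique(y, return_inverse=True)`` and ``self.classes_[np.argmax(D, axis=1)]`` compute; the helper names ``encode_labels`` and ``predict_from_scores`` are hypothetical and not part of scikit-learn.

```python
def encode_labels(y):
    """Return (classes, indices), mirroring np.unique(y, return_inverse=True)."""
    classes = sorted(set(y))  # sorted unique labels, as np.unique would give
    position = {label: i for i, label in enumerate(classes)}
    return classes, [position[label] for label in y]


def predict_from_scores(classes, scores):
    """Map each row of decision scores to the label with the highest score."""
    # max(...) returns the first index on ties, like np.argmax.
    return [classes[max(range(len(row)), key=row.__getitem__)]
            for row in scores]


# String labels become indexes in the range [0, n_classes).
classes, y_idx = encode_labels(["spam", "ham", "spam", "eggs"])

# Columns of the score matrix follow the order of `classes`.
predicted = predict_from_scores(classes, [[0.1, 0.7, 0.2], [0.3, 0.3, 0.9]])
```

Because the score columns follow the order of ``classes``, the predictions come back as the original labels rather than as integer indexes.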
14551451

1452+
Estimator Tags
1453+
--------------
1454+
.. warning::
1455+
1456+
The estimator tags are experimental and the API is subject to change.
1457+
1458+
Scikit-learn introduced estimator tags in version 0.21. These are annotations
1459+
of estimators that allow programmatic inspection of their capabilities, such as
1460+
sparse matrix support, supported output types and supported methods. The
1461+
estimator tags are a dictionary returned by the method ``_get_tags()``. These
1462+
tags are used by the common tests and the :func:`sklearn.utils.estimator_checks.check_estimator` function to
1463+
decide what tests to run and what input data is appropriate. Tags can depend on
1464+
estimator parameters or even system architecture and can in general only be
1465+
determined at runtime.
1466+
1467+
The default value of all tags except for ``X_types`` is ``False``.
1468+
1469+
The current set of estimator tags is:
1470+
1471+
non_deterministic
1472+
whether the estimator is not deterministic given a fixed ``random_state``
1473+
1474+
requires_positive_data - unused for now
1475+
whether the estimator requires positive X.
1476+
1477+
no_validation
1478+
whether the estimator skips input-validation. This is only meant for stateless and dummy transformers!
1479+
1480+
multioutput - unused for now
1481+
whether a regressor supports multi-target outputs or a classifier supports multi-class multi-output.
1482+
1483+
multilabel
1484+
whether the estimator supports multilabel output
1485+
1486+
stateless
1487+
whether the estimator needs access to data for fitting. Even though
1488+
an estimator is stateless, it might still need a call to ``fit`` for initialization.
1489+
1490+
allow_nan
1491+
whether the estimator supports data with missing values encoded as np.NaN
1492+
1493+
poor_score
1494+
whether the estimator fails to provide a "reasonable" test-set score, which
1495+
currently for regression is an R2 of 0.5 on a subset of the boston housing
1496+
dataset, and for classification an accuracy of 0.83 on
1497+
``make_blobs(n_samples=300, random_state=0)``. These datasets and values
1498+
are based on current estimators in sklearn and might be replaced by
1499+
something more systematic.
1500+
1501+
multioutput_only
1502+
whether estimator supports only multi-output classification or regression.
1503+
1504+
_skip_test
1505+
whether to skip common tests entirely. Don't use this unless you have a *very good* reason.
1506+
1507+
X_types
1508+
Supported input types for X as list of strings. Tests are currently only run if '2darray' is contained
1509+
in the list, signifying that the estimator takes continuous 2d numpy arrays as input. The default
1510+
value is ['2darray']. Other possible types are ``'string'``, ``'sparse'``,
1511+
``'categorical'``, ``dict``, ``'1dlabels'`` and ``'2dlabels'``.
1512+
The goal is that in the future the supported input type will determine the
1513+
data used during testing, in particular for ``'string'``, ``'sparse'`` and
1514+
``'categorical'`` data. For now, the tests for sparse data do not make use
1515+
of the ``'sparse'`` tag.
1516+
1517+
1518+
In addition to the tags, estimators also need to declare any non-optional
1519+
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
1520+
which is a list or tuple. If ``_required_parameters`` is only
1521+
``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
1522+
instantiated with an instance of ``LinearDiscriminantAnalysis`` (or
1523+
``Ridge`` if the estimator is a regressor) in the tests. The choice
1524+
of these two models is somewhat idiosyncratic but both should provide robust
1525+
closed-form solutions.
1526+
14561527
.. _reading-code:
14571528

14581529
Reading the existing code base
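
As a rough illustration of the estimator-tag documentation added to ``contributing.rst`` in this commit, the sketch below shows how a tag dictionary with defaults could drive test selection. It does not import scikit-learn. ``resolve_tags`` and ``select_checks`` are hypothetical helpers, and the check names are made up for the example; only the tag names and default values are taken from the diff.

```python
# Hypothetical sketch, not scikit-learn's actual test machinery.
DEFAULT_TAGS = {
    "non_deterministic": False,
    "requires_positive_data": False,
    "X_types": ["2darray"],
    "poor_score": False,
    "no_validation": False,
    "multioutput": False,
    "allow_nan": False,
    "stateless": False,
    "multilabel": False,
    "_skip_test": False,
    "multioutput_only": False,
}


def resolve_tags(overrides):
    """Merge estimator-specific overrides into a copy of the defaults."""
    tags = dict(DEFAULT_TAGS)
    tags["X_types"] = list(DEFAULT_TAGS["X_types"])  # do not share the list
    tags.update(overrides)
    return tags


def select_checks(tags):
    """Return the names of the (made-up) checks that would run."""
    # Per the docs above, tests only run when '2darray' is supported.
    if tags["_skip_test"] or "2darray" not in tags["X_types"]:
        return []
    checks = ["check_fit_returns_self"]
    if not tags["no_validation"]:
        checks.append("check_input_validation")
    if not tags["poor_score"]:
        checks.append("check_reasonable_score")
    if not tags["non_deterministic"]:
        checks.append("check_repeatable_output")
    return checks


default_checks = select_checks(resolve_tags({}))
reduced_checks = select_checks(
    resolve_tags({"poor_score": True, "no_validation": True}))
```

An estimator that declares ``poor_score`` and ``no_validation`` ends up with fewer checks in this sketch, which is the kind of decision the tags are meant to make possible.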

sklearn/base.py

Lines changed: 58 additions & 11 deletions
@@ -6,12 +6,25 @@
66
import copy
77
import warnings
88
from collections import defaultdict
9-
from inspect import signature
9+
import struct
10+
import inspect
1011

1112
import numpy as np
1213

1314
from . import __version__
1415

16+
_DEFAULT_TAGS = {
17+
'non_deterministic': False,
18+
'requires_positive_data': False,
19+
'X_types': ['2darray'],
20+
'poor_score': False,
21+
'no_validation': False,
22+
'multioutput': False,
23+
"allow_nan": False,
24+
'stateless': False,
25+
'multilabel': False,
26+
'_skip_test': False,
27+
'multioutput_only': False}
1528

1629

1730
def clone(estimator, safe=True):
@@ -61,7 +74,6 @@ def clone(estimator, safe=True):
6174
return new_object
6275

6376

64-
###############################################################################
6577
def _pprint(params, offset=0, printer=repr):
6678
"""Pretty print the dictionary 'params'
6779
@@ -112,7 +124,17 @@ def _pprint(params, offset=0, printer=repr):
112124
return lines
113125

114126

115-
###############################################################################
127+
def _update_if_consistent(dict1, dict2):
128+
common_keys = set(dict1.keys()).intersection(dict2.keys())
129+
for key in common_keys:
130+
if dict1[key] != dict2[key]:
131+
raise TypeError("Inconsistent values for tag {}: {} != {}".format(
132+
key, dict1[key], dict2[key]
133+
))
134+
dict1.update(dict2)
135+
return dict1
136+
137+
116138
class BaseEstimator:
117139
"""Base class for all estimators in scikit-learn
118140
@@ -135,7 +157,7 @@ def _get_param_names(cls):
135157

136158
# introspect the constructor arguments to find the model parameters
137159
# to represent
138-
init_signature = signature(init)
160+
init_signature = inspect.signature(init)
139161
# Consider the constructor parameters excluding 'self'
140162
parameters = [p for p in init_signature.parameters.values()
141163
if p.name != 'self' and p.kind != p.VAR_KEYWORD]
@@ -255,8 +277,22 @@ def __setstate__(self, state):
255277
except AttributeError:
256278
self.__dict__.update(state)
257279

280+
def _get_tags(self):
281+
collected_tags = {}
282+
for base_class in inspect.getmro(self.__class__):
283+
if (hasattr(base_class, '_more_tags')
284+
and base_class != self.__class__):
285+
more_tags = base_class._more_tags(self)
286+
collected_tags = _update_if_consistent(collected_tags,
287+
more_tags)
288+
if hasattr(self, '_more_tags'):
289+
more_tags = self._more_tags()
290+
collected_tags = _update_if_consistent(collected_tags, more_tags)
291+
tags = _DEFAULT_TAGS.copy()
292+
tags.update(collected_tags)
293+
return tags
294+
258295

259-
###############################################################################
260296
class ClassifierMixin:
261297
"""Mixin class for all classifiers in scikit-learn."""
262298
_estimator_type = "classifier"
@@ -289,7 +325,6 @@ def score(self, X, y, sample_weight=None):
289325
return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
290326

291327

292-
###############################################################################
293328
class RegressorMixin:
294329
"""Mixin class for all regression estimators in scikit-learn."""
295330
_estimator_type = "regressor"
@@ -330,7 +365,6 @@ def score(self, X, y, sample_weight=None):
330365
multioutput='variance_weighted')
331366

332367

333-
###############################################################################
334368
class ClusterMixin:
335369
"""Mixin class for all cluster estimators in scikit-learn."""
336370
_estimator_type = "clusterer"
@@ -432,7 +466,6 @@ def get_submatrix(self, i, data):
432466
return data[row_ind[:, np.newaxis], col_ind]
433467

434468

435-
###############################################################################
436469
class TransformerMixin:
437470
"""Mixin class for all transformers in scikit-learn."""
438471

@@ -510,13 +543,27 @@ def fit_predict(self, X, y=None):
510543
return self.fit(X).predict(X)
511544

512545

513-
###############################################################################
514546
class MetaEstimatorMixin:
547+
_required_parameters = ["estimator"]
515548
"""Mixin class for all meta estimators in scikit-learn."""
516-
# this is just a tag for the moment
517549

518550

519-
###############################################################################
551+
class MultiOutputMixin(object):
552+
"""Mixin to mark estimators that support multioutput."""
553+
def _more_tags(self):
554+
return {'multioutput': True}
555+
556+
557+
def _is_32bit():
558+
"""Detect if process is 32bit Python."""
559+
return struct.calcsize('P') * 8 == 32
560+
561+
562+
class _UnstableOn32BitMixin(object):
563+
"""Mark estimators that are non-deterministic on 32bit."""
564+
def _more_tags(self):
565+
return {'non_deterministic': _is_32bit()}
566+
520567

521568
def is_classifier(estimator):
522569
"""Returns True if the given estimator is (probably) a classifier.

sklearn/compose/_column_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
158158
[0.5, 0.5, 0. , 1. ]])
159159
160160
"""
161+
_required_parameters = ['transformers']
161162

162163
def __init__(self, transformers, remainder='drop', sparse_threshold=0.3,
163164
n_jobs=None, transformer_weights=None):

sklearn/compose/_target.py

Lines changed: 3 additions & 0 deletions
@@ -233,3 +233,6 @@ def predict(self, X):
233233
pred_trans = pred_trans.squeeze(axis=1)
234234

235235
return pred_trans
236+
237+
def _more_tags(self):
238+
return {'poor_score': True, 'no_validation': True}

sklearn/cross_decomposition/cca_.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
11
from .pls_ import _PLS
2+
from ..base import _UnstableOn32BitMixin
23

34
__all__ = ['CCA']
45

56

6-
class CCA(_PLS):
7+
class CCA(_PLS, _UnstableOn32BitMixin):
78
"""CCA Canonical Correlation Analysis.
89
910
CCA inherits from PLS with mode="B" and deflation_mode="canonical".
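
``CCA`` picks up ``_UnstableOn32BitMixin`` here, so its ``non_deterministic`` tag depends on the interpreter it runs in. A minimal standalone sketch of that runtime-dependent tag, using only the standard library, might look like the following. ``is_32bit``, ``UnstableOn32BitMixin`` and ``ToyCCA`` are illustrative copies and stand-ins, not imports from scikit-learn.

```python
import struct


def is_32bit():
    """Return True when the running interpreter uses 32-bit pointers."""
    # struct.calcsize('P') is the size in bytes of a C pointer (void *).
    return struct.calcsize("P") * 8 == 32


class UnstableOn32BitMixin:
    """Mark an estimator as non-deterministic on 32-bit builds only."""

    def _more_tags(self):
        return {"non_deterministic": is_32bit()}


class ToyCCA(UnstableOn32BitMixin):  # hypothetical stand-in for CCA
    pass


pointer_bits = struct.calcsize("P") * 8
tags = ToyCCA()._more_tags()
print(pointer_bits, tags)
```

Because the value is computed when ``_more_tags`` is called, the same class reports different tags on different platforms, which is why the documentation says tags can in general only be determined at runtime.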

sklearn/cross_decomposition/pls_.py

Lines changed: 5 additions & 1 deletion
@@ -13,6 +13,7 @@
1313
from scipy.sparse.linalg import svds
1414

1515
from ..base import BaseEstimator, RegressorMixin, TransformerMixin
16+
from ..base import MultiOutputMixin
1617
from ..utils import check_array, check_consistent_length
1718
from ..utils.extmath import svd_flip
1819
from ..utils.validation import check_is_fitted, FLOAT_DTYPES
@@ -116,7 +117,7 @@ def _center_scale_xy(X, Y, scale=True):
116117
return X, Y, x_mean, y_mean, x_std, y_std
117118

118119

119-
class _PLS(BaseEstimator, TransformerMixin, RegressorMixin,
120+
class _PLS(BaseEstimator, TransformerMixin, RegressorMixin, MultiOutputMixin,
120121
metaclass=ABCMeta):
121122
"""Partial Least Squares (PLS)
122123
@@ -454,6 +455,9 @@ def fit_transform(self, X, y=None):
454455
"""
455456
return self.fit(X, y).transform(X, y)
456457

458+
def _more_tags(self):
459+
return {'poor_score': True}
460+
457461

458462
class PLSRegression(_PLS):
459463
"""PLS regression

sklearn/decomposition/kernel_pca.py

Lines changed: 2 additions & 2 deletions
@@ -10,12 +10,12 @@
1010
from ..utils import check_random_state
1111
from ..utils.validation import check_is_fitted, check_array
1212
from ..exceptions import NotFittedError
13-
from ..base import BaseEstimator, TransformerMixin
13+
from ..base import BaseEstimator, TransformerMixin, _UnstableOn32BitMixin
1414
from ..preprocessing import KernelCenterer
1515
from ..metrics.pairwise import pairwise_kernels
1616

1717

18-
class KernelPCA(BaseEstimator, TransformerMixin):
18+
class KernelPCA(BaseEstimator, TransformerMixin, _UnstableOn32BitMixin):
1919
"""Kernel Principal component analysis (KPCA)
2020
2121
Non-linear dimensionality reduction through the use of kernels (see

sklearn/decomposition/truncated_svd.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@ def fit_transform(self, X, y=None):
156156
X_new : array, shape (n_samples, n_components)
157157
Reduced version of X. This will always be a dense array.
158158
"""
159-
X = check_array(X, accept_sparse=['csr', 'csc'])
159+
X = check_array(X, accept_sparse=['csr', 'csc'],
160+
ensure_min_features=2)
160161
random_state = check_random_state(self.random_state)
161162

162163
if self.algorithm == "arpack":
