10000 DOC improve docstring of BaseEstimator, ClassifierMixin, and Regresso… · glemaitre/scikit-learn@7a12cb9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7a12cb9

Browse files
Higgs32584 and glemaitre committed
DOC improve docstring of BaseEstimator, ClassifierMixin, and RegressorMixin (scikit-learn#28030)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 5b0bc65 commit 7a12cb9

File tree

1 file changed

+96
-2
lines changed

1 file changed

+96
-2
lines changed

sklearn/base.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,45 @@ def _clone_parametrized(estimator, *, safe=True):
137137
class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
138138
"""Base class for all estimators in scikit-learn.
139139
140+
Inheriting from this class provides default implementations of:
141+
142+
- setting and getting parameters used by `GridSearchCV` and friends;
143+
- textual and HTML representation displayed in terminals and IDEs;
144+
- estimator serialization;
145+
- parameter validation;
146+
- data validation;
147+
- feature names validation.
148+
149+
Read more in the :ref:`User Guide <rolling_your_own_estimator>`.
150+
151+
140152
Notes
141153
-----
142154
All estimators should specify all the parameters that can be set
143155
at the class level in their ``__init__`` as explicit keyword
144156
arguments (no ``*args`` or ``**kwargs``).
157+
158+
Examples
159+
--------
160+
>>> import numpy as np
161+
>>> from sklearn.base import BaseEstimator
162+
>>> class MyEstimator(BaseEstimator):
163+
... def __init__(self, *, param=1):
164+
... self.param = param
165+
... def fit(self, X, y=None):
166+
... self.is_fitted_ = True
167+
... return self
168+
... def predict(self, X):
169+
... return np.full(shape=X.shape[0], fill_value=self.param)
170+
>>> estimator = MyEstimator(param=2)
171+
>>> estimator.get_params()
172+
{'param': 2}
173+
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
174+
>>> y = np.array([1, 0, 1])
175+
>>> estimator.fit(X, y).predict(X)
176+
array([2, 2, 2])
177+
>>> estimator.set_params(param=3).fit(X, y).predict(X)
178+
array([3, 3, 3])
145179
"""
146180

147181
@classmethod
@@ -652,7 +686,37 @@ def _repr_mimebundle_(self, **kwargs):
652686

653687

654688
class ClassifierMixin:
655-
"""Mixin class for all classifiers in scikit-learn."""
689+
"""Mixin class for all classifiers in scikit-learn.
690+
691+
This mixin defines the following functionality:
692+
693+
- `_estimator_type` class attribute defaulting to `"classifier"`;
694+
- `score` method that defaults to :func:`~sklearn.metrics.accuracy_score`;
695+
- enforcement, through the `requires_y` tag, that `fit` requires `y` to be passed.
696+
697+
Read more in the :ref:`User Guide <rolling_your_own_estimator>`.
698+
699+
Examples
700+
--------
701+
>>> import numpy as np
702+
>>> from sklearn.base import BaseEstimator, ClassifierMixin
703+
>>> # Mixin classes should always be on the left-hand side for a correct MRO
704+
>>> class MyEstimator(ClassifierMixin, BaseEstimator):
705+
... def __init__(self, *, param=1):
706+
... self.param = param
707+
... def fit(self, X, y=None):
708+
... self.is_fitted_ = True
709+
... return self
710+
... def predict(self, X):
711+
... return np.full(shape=X.shape[0], fill_value=self.param)
712+
>>> estimator = MyEstimator(param=1)
713+
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
714+
>>> y = np.array([1, 0, 1])
715+
>>> estimator.fit(X, y).predict(X)
716+
array([1, 1, 1])
717+
>>> estimator.score(X, y)
718+
0.66...
719+
"""
656720

657721
_estimator_type = "classifier"
658722

@@ -689,7 +753,37 @@ def _more_tags(self):
689753

690754

691755
class RegressorMixin:
692-
"""Mixin class for all regression estimators in scikit-learn."""
756+
"""Mixin class for all regression estimators in scikit-learn.
757+
758+
This mixin defines the following functionality:
759+
760+
- `_estimator_type` class attribute defaulting to `"regressor"`;
761+
- `score` method that defaults to :func:`~sklearn.metrics.r2_score`;
762+
- enforcement, through the `requires_y` tag, that `fit` requires `y` to be passed.
763+
764+
Read more in the :ref:`User Guide <rolling_your_own_estimator>`.
765+
766+
Examples
767+
--------
768+
>>> import numpy as np
769+
>>> from sklearn.base import BaseEstimator, RegressorMixin
770+
>>> # Mixin classes should always be on the left-hand side for a correct MRO
771+
>>> class MyEstimator(RegressorMixin, BaseEstimator):
772+
... def __init__(self, *, param=1):
773+
... self.param = param
774+
... def fit(self, X, y=None):
775+
... self.is_fitted_ = True
776+
... return self
777+
... def predict(self, X):
778+
... return np.full(shape=X.shape[0], fill_value=self.param)
779+
>>> estimator = MyEstimator(param=0)
780+
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
781+
>>> y = np.array([-1, 0, 1])
782+
>>> estimator.fit(X, y).predict(X)
783+
array([0, 0, 0])
784+
>>> estimator.score(X, y)
785+
0.0
786+
"""
693787

694788
_estimator_type = "regressor"
695789

0 commit comments

Comments
 (0)
0