scikit-learn/scikit-learn · Commit 067adad
jnothman authored and lesteve committed
DOC describe scikit-learn-contrib in related projects and contributing docs (#8440)
1 parent fdb32e2 commit 067adad

2 files changed (+74, -43 lines changed)

doc/developers/contributing.rst

Lines changed: 67 additions & 43 deletions
@@ -957,49 +957,73 @@ adheres to the scikit-learn interface and standards by running
   >>> check_estimator(LinearSVC)  # passes
 
 The main motivation to make a class compatible to the scikit-learn estimator
-interface might be that you want to use it together with model assessment and
-selection tools such as :class:`model_selection.GridSearchCV`.
-
-For this to work, you need to implement the following interface.
-If a dependency on scikit-learn is okay for your code,
-you can prevent a lot of boilerplate code
-by deriving a class from ``BaseEstimator``
-and optionally the mixin classes in ``sklearn.base``.
-E.g., below is a custom classifier. For more information on this example, see
-`scikit-learn-contrib <https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/template.py>`_::
-
-  >>> import numpy as np
-  >>> from sklearn.base import BaseEstimator, ClassifierMixin
-  >>> from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
-  >>> from sklearn.utils.multiclass import unique_labels
-  >>> from sklearn.metrics import euclidean_distances
-  >>> class TemplateClassifier(BaseEstimator, ClassifierMixin):
-  ...
-  ...     def __init__(self, demo_param='demo'):
-  ...         self.demo_param = demo_param
-  ...
-  ...     def fit(self, X, y):
-  ...
-  ...         # Check that X and y have correct shape
-  ...         X, y = check_X_y(X, y)
-  ...         # Store the classes seen during fit
-  ...         self.classes_ = unique_labels(y)
-  ...
-  ...         self.X_ = X
-  ...         self.y_ = y
-  ...         # Return the classifier
-  ...         return self
-  ...
-  ...     def predict(self, X):
-  ...
-  ...         # Check is fit had been called
-  ...         check_is_fitted(self, ['X_', 'y_'])
-  ...
-  ...         # Input validation
-  ...         X = check_array(X)
-  ...
-  ...         closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
-  ...         return self.y_[closest]
+interface might be that you want to use it together with model evaluation and
+selection tools such as :class:`model_selection.GridSearchCV` and
+:class:`pipeline.Pipeline`.
+
+Before detailing the required interface below, we describe two ways to achieve
+the correct interface more easily.
+
+.. topic:: Project template:
+
+    We provide a `project template <https://github.com/scikit-learn-contrib/project-template/>`_
+    which helps in the creation of Python packages containing scikit-learn compatible estimators.
+    It provides:
+
+    * an initial git repository with Python package directory structure
+    * a template of a scikit-learn estimator
+    * an initial test suite including use of ``check_estimator``
+    * directory structures and scripts to compile documentation and example
+      galleries
+    * scripts to manage continuous integration (testing on Linux and Windows)
+    * instructions from getting started to publishing on `PyPI <https://pypi.python.org/pypi>`_
+
+.. topic:: ``BaseEstimator`` and mixins:
+
+    We tend to use "duck typing", so building an estimator which follows
+    the API suffices for compatibility, without needing to inherit from or
+    even import any scikit-learn classes.
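The duck-typing route described in this paragraph can be sketched as below. This is only an illustration under stated assumptions: `MajorityClassifier` and its `tie_break` parameter are hypothetical names, not part of scikit-learn or the project template, and the class imports nothing from scikit-learn. It follows the method conventions (the constructor only stores parameters, `fit` returns `self`, learned attributes end in an underscore). Whether a given scikit-learn tool accepts such an object may depend on the installed version.

```python
from collections import Counter


class MajorityClassifier:
    """Hypothetical duck-typed estimator: predicts the most common label."""

    def __init__(self, tie_break="first"):
        # The constructor only stores parameters; no validation or work here.
        self.tie_break = tie_break

    def get_params(self, deep=True):
        # Report the constructor parameters by name.
        return {"tie_break": self.tie_break}

    def set_params(self, **params):
        # Accept only names that get_params() reports, then return self.
        for name, value in params.items():
            if name not in self.get_params():
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        counts = Counter(y)
        top = max(counts.values())
        # Counter preserves insertion order, so ties are listed in the
        # order the labels were first seen in y.
        tied = [label for label in counts if counts[label] == top]
        self.majority_ = tied[0] if self.tie_break == "first" else tied[-1]
        self.classes_ = sorted(counts)
        return self

    def predict(self, X):
        # Every sample receives the majority label learned in fit.
        return [self.majority_ for _ in X]

    def score(self, X, y):
        # Mean accuracy of the predictions against y.
        predictions = self.predict(X)
        return sum(p == t for p, t in zip(predictions, y)) / len(y)


clf = MajorityClassifier().fit([[0], [1], [2]], ["x", "y", "y"])
print(clf.predict([[9], [9]]))
```

Writing `get_params`, `set_params` and `score` by hand is the boilerplate that deriving from ``BaseEstimator`` and the mixins is meant to remove.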
+
+    However, if a dependency on scikit-learn is acceptable in your code,
+    you can prevent a lot of boilerplate code
+    by deriving a class from ``BaseEstimator``
+    and optionally the mixin classes in ``sklearn.base``.
+    For example, below is a custom classifier, with more examples included
+    in the scikit-learn-contrib
+    `project template <https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/template.py>`_.
+
+      >>> import numpy as np
+      >>> from sklearn.base import BaseEstimator, ClassifierMixin
+      >>> from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
+      >>> from sklearn.utils.multiclass import unique_labels
+      >>> from sklearn.metrics import euclidean_distances
+      >>> class TemplateClassifier(BaseEstimator, ClassifierMixin):
+      ...
+      ...     def __init__(self, demo_param='demo'):
+      ...         self.demo_param = demo_param
+      ...
+      ...     def fit(self, X, y):
+      ...
+      ...         # Check that X and y have correct shape
+      ...         X, y = check_X_y(X, y)
+      ...         # Store the classes seen during fit
+      ...         self.classes_ = unique_labels(y)
+      ...
+      ...         self.X_ = X
+      ...         self.y_ = y
+      ...         # Return the classifier
+      ...         return self
+      ...
+      ...     def predict(self, X):
+      ...
+      ...         # Check if fit has been called
+      ...         check_is_fitted(self, ['X_', 'y_'])
+      ...
+      ...         # Input validation
+      ...         X = check_array(X)
+      ...
+      ...         closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
+      ...         return self.y_[closest]
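The motivation stated in this hunk, using a custom estimator with `GridSearchCV` and `Pipeline`, might look like the sketch below. It restates the classifier from the diff under the hypothetical name `NearestPointClassifier`. The step names `scale` and `clf`, the iris data and the grid values are illustrative choices only. One deliberate difference from the diff is an assumption about recent scikit-learn releases: the mixin is listed before ``BaseEstimator``, which the current developer guide recommends.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.metrics import euclidean_distances
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class NearestPointClassifier(ClassifierMixin, BaseEstimator):
    """Predicts the label of the closest training point (as in the diff)."""

    def __init__(self, demo_param="demo"):
        # demo_param is unused by the model; it exists so the grid search
        # below has a parameter to set through set_params.
        self.demo_param = demo_param

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.classes_ = unique_labels(y)
        self.X_ = X
        self.y_ = y
        return self

    def predict(self, X):
        check_is_fitted(self, ["X_", "y_"])
        X = check_array(X)
        closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
        return self.y_[closest]


X, y = load_iris(return_X_y=True)

# The custom estimator is the final step of an ordinary pipeline.
pipe = Pipeline([("scale", StandardScaler()), ("clf", NearestPointClassifier())])

# GridSearchCV clones the pipeline and sets "clf__demo_param" via the
# get_params/set_params methods inherited from BaseEstimator.
search = GridSearchCV(pipe, {"clf__demo_param": ["a", "b"]}, cv=3)
search.fit(X, y)
print(search.best_params_)
```

Because `demo_param` does not influence predictions, both grid values score the same. The example only shows that parameter routing and cloning work without extra code in the estimator.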
 
 
 get_params and set_params

doc/related_projects.rst

Lines changed: 7 additions & 0 deletions
@@ -4,6 +4,13 @@
 Related Projects
 =====================================
 
+Projects implementing the scikit-learn estimator API are encouraged to use
+the `scikit-learn-contrib template <https://github.com/scikit-learn-contrib/project-template>`_
+which facilitates best practices for testing and documenting estimators.
+The `scikit-learn-contrib GitHub organisation <https://github.com/scikit-learn-contrib/scikit-learn-contrib>`_
+also accepts high-quality contributions of repositories conforming to this
+template.
+
 Below is a list of sister-projects, extensions and domain specific packages.
 
 Interoperability and framework enhancements
