@@ -957,49 +957,73 @@ adheres to the scikit-learn interface and standards by running
957
957
>>> check_estimator(LinearSVC) # passes
958
958
959
959
The main motivation to make a class compatible to the scikit-learn estimator
960
- interface might be that you want to use it together with model assessment and
961
- selection tools such as :class: `model_selection.GridSearchCV `.
962
-
963
- For this to work, you need to implement the following interface.
964
- If a dependency on scikit-learn is okay for your code,
965
- you can prevent a lot of boilerplate code
966
- by deriving a class from ``BaseEstimator ``
967
- and optionally the mixin classes in ``sklearn.base ``.
968
- E.g., below is a custom classifier. For more information on this example, see
969
- `scikit-learn-contrib <https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/template.py >`_::
970
-
971
- >>> import numpy as np
972
- >>> from sklearn.base import BaseEstimator, ClassifierMixin
973
- >>> from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
974
- >>> from sklearn.utils.multiclass import unique_labels
975
- >>> from sklearn.metrics import euclidean_distances
976
- >>> class TemplateClassifier(BaseEstimator, ClassifierMixin):
977
- ...
978
- ... def __init__(self, demo_param='demo'):
979
- ... self.demo_param = demo_param
980
- ...
981
- ... def fit(self, X, y):
982
- ...
983
- ... # Check that X and y have correct shape
984
- ... X, y = check_X_y(X, y)
985
- ... # Store the classes seen during fit
986
- ... self.classes_ = unique_labels(y)
987
- ...
988
- ... self.X_ = X
989
- ... self.y_ = y
990
- ... # Return the classifier
991
- ... return self
992
- ...
993
- ... def predict(self, X):
994
- ...
995
- ... # Check is fit had been called
996
- ... check_is_fitted(self, ['X_', 'y_'])
997
- ...
998
- ... # Input validation
999
- ... X = check_array(X)
1000
- ...
1001
- ... closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
1002
- ... return self.y_[closest]
960
+ interface might be that you want to use it together with model evaluation and
961
+ selection tools such as :class:`model_selection.GridSearchCV` and
962
+ :class:`pipeline.Pipeline`.
963
+
964
+ Before detailing the required interface below, we describe two ways to achieve
965
+ the correct interface more easily.
966
+
967
+ .. topic:: Project template:
968
+
969
+ We provide a `project template <https://github.com/scikit-learn-contrib/project-template/>`_
970
+ which helps in the creation of Python packages containing scikit-learn compatible estimators.
971
+ It provides:
972
+
973
+ * an initial git repository with Python package directory structure
974
+ * a template of a scikit-learn estimator
975
+ * an initial test suite including use of ``check_estimator``
976
+ * directory structures and scripts to compile documentation and example
977
+ galleries
978
+ * scripts to manage continuous integration (testing on Linux and Windows)
979
+ * instructions from getting started to publishing on `PyPI <https://pypi.python.org/pypi>`_
980
+
981
+ .. topic:: ``BaseEstimator`` and mixins:
982
+
983
+ We tend to use "duck typing", so building an estimator which follows
984
+ the API suffices for compatibility, without needing to inherit from or
985
+ even import any scikit-learn classes.
986
+
987
+ However, if a dependency on scikit-learn is acceptable in your code,
988
+ you can prevent a lot of boilerplate code
989
+ by deriving a class from ``BaseEstimator``
990
+ and optionally the mixin classes in ``sklearn.base``.
991
+ For example, below is a custom classifier, with more examples included
992
+ in the scikit-learn-contrib
993
+ `project template <https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/template.py>`_.
994
+
995
+ >>> import numpy as np
996
+ >>> from sklearn.base import BaseEstimator, ClassifierMixin
997
+ >>> from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
998
+ >>> from sklearn.utils.multiclass import unique_labels
999
+ >>> from sklearn.metrics import euclidean_distances
1000
+ >>> class TemplateClassifier (BaseEstimator , ClassifierMixin ):
1001
+ ...
1002
+ ... def __init__ (self , demo_param = ' demo' ):
1003
+ ... self .demo_param = demo_param
1004
+ ...
1005
+ ... def fit (self , X , y ):
1006
+ ...
1007
+ ... # Check that X and y have correct shape
1008
+ ... X, y = check_X_y(X, y)
1009
+ ... # Store the classes seen during fit
1010
+ ... self .classes_ = unique_labels(y)
1011
+ ...
1012
+ ... self .X_ = X
1013
+ ... self .y_ = y
1014
+ ... # Return the classifier
1015
+ ... return self
1016
+ ...
1017
+ ... def predict (self , X ):
1018
+ ...
1019
+ ... # Check if fit has been called
1020
+ ... check_is_fitted(self , [' X_' , ' y_' ])
1021
+ ...
1022
+ ... # Input validation
1023
+ ... X = check_array(X)
1024
+ ...
1025
+ ... closest = np.argmin(euclidean_distances(X, self .X_), axis = 1 )
1026
+ ... return self .y_[closest]
1003
1027
1004
1028
1005
1029
get_params and set_params
0 commit comments