8000 Use base.is_classifier instead of isinstance · scikit-learn/scikit-learn@35c1037 · GitHub
[go: up one dir, main page]

Skip to content

Commit 35c1037

Browse files
minghui-liu authored and lesteve committed
Use base.is_classifier instead of isinstance
1 parent 1674412 commit 35c1037

File tree

6 files changed

+16
-16
lines changed

6 files changed

+16
-16
lines changed

sklearn/ensemble/weight_boosting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from numpy.core.umath_tests import inner1d
3030

3131
from .base import BaseEnsemble
32-
from ..base import ClassifierMixin, RegressorMixin, is_regressor
32+
from ..base import ClassifierMixin, RegressorMixin, is_regressor, is_classifier
3333
from ..externals import six
3434
from ..externals.six.moves import zip
3535
from ..externals.six.moves import xrange as range
@@ -231,7 +231,7 @@ def staged_score(self, X, y, sample_weight=None):
231231
z : float
232232
"""
233233
for y_pred in self.staged_predict(X):
234-
if isinstance(self, ClassifierMixin):
234+
if is_classifier(self):
235235
yield accuracy_score(y, y_pred, sample_weight=sample_weight)
236236
else:
237237
yield r2_score(y, y_pred, sample_weight=sample_weight)

sklearn/multioutput.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import scipy.sparse as sp
1919
from abc import ABCMeta, abstractmethod
2020
from .base import BaseEstimator, clone, MetaEstimatorMixin
21-
from .base import RegressorMixin, ClassifierMixin
21+
from .base import RegressorMixin, ClassifierMixin, is_classifier
2222
from .model_selection import cross_val_predict
2323
from .utils import check_array, check_X_y, check_random_state
2424
from .utils.fixes import parallel_helper
@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
152152
multi_output=True,
153153
accept_sparse=True)
154154

155-
if isinstance(self, ClassifierMixin):
155+
if is_classifier(self):
156156
check_classification_targets(y)
157157

158158
if y.ndim == 1:

sklearn/neural_network/multilayer_perceptron.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import warnings
1414

1515
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
16+
from ..base import is_classifier
1617
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
1718
from ._stochastic_optimizers import SGDOptimizer, AdamOptimizer
1819
from ..model_selection import train_test_split
@@ -268,7 +269,7 @@ def _initialize(self, y, layer_units):
268269
self.n_layers_ = len(layer_units)
269270

270271
# Output for regression
271-
if not isinstance(self, ClassifierMixin):
272+
if not is_classifier(self):
272273
self.out_activation_ = 'identity'
273274
# Output for multi class
274275
elif self._label_binarizer.y_type_ == 'multiclass':
@@ -491,7 +492,7 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
491492
X, X_val, y, y_val = train_test_split(
492493
X, y, random_state=self._random_state,
493494
test_size=self.validation_fraction)
494-
if isinstance(self, ClassifierMixin):
495+
if is_classifier(self):
495496
y_val = self._label_binarizer.inverse_transform(y_val)
496497
else:
497498
X_val = None

sklearn/tree/tests/test_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from numpy.random import RandomState
88

9-
from sklearn.base import ClassifierMixin
9+
from sklearn.base import is_classifier
1010
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
1111
from sklearn.ensemble import GradientBoostingClassifier
1212
from sklearn.tree import export_graphviz
@@ -292,7 +292,7 @@ def test_precision():
292292
len(search("\.\d+", finding.group()).group()),
293293
precision + 1)
294294
# check impurity
295-
if isinstance(clf, ClassifierMixin):
295+
if is_classifier(clf):
296296
pattern = "gini = \d+\.\d+"
297297
else:
298298
pattern = "friedman_mse = \d+\.\d+"

sklearn/tree/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
123123

124124
# Determine output settings
125125
n_samples, self.n_features_ = X.shape
126-
is_classification = isinstance(self, ClassifierMixin)
126+
is_classification = is_classifier(self)
127127

128128
y = np.atleast_1d(y)
129129
expanded_class_weight = None
@@ -413,7 +413,7 @@ def predict(self, X, check_input=True):
413413
n_samples = X.shape[0]
414414

415415
# Classification
416-
if isinstance(self, ClassifierMixin):
416+
if is_classifier(self):
417417
if self.n_outputs_ == 1:
418418
return self.classes_.take(np.argmax(proba, axis=1), axis=0)
419419

sklearn/utils/estimator_checks.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import print_function
2-
32
import types
43
import warnings
54
import sys
@@ -35,8 +34,8 @@
3534
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
3635

3736

38-
from sklearn.base import (clone, ClassifierMixin, RegressorMixin,
39-
TransformerMixin, ClusterMixin, BaseEstimator)
37+
from sklearn.base import (clone, TransformerMixin, ClusterMixin,
38+
BaseEstimator, is_classifier, is_regressor)
4039
from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score
4140

4241
from sklearn.random_projection import BaseRandomProjection
@@ -208,10 +207,10 @@ def _yield_clustering_checks(name, clusterer):
208207
def _yield_all_checks(name, estimator):
209208
for check in _yield_non_meta_checks(name, estimator):
210209
yield check
211-
if isinstance(estimator, ClassifierMixin):
210+
if is_classifier(estimator):
212211
for check in _yield_classifier_checks(name, estimator):
213212
yield check
214-
if isinstance(estimator, RegressorMixin):
213+
if is_regressor(estimator):
215214
for check in _yield_regressor_checks(name, estimator):
216215
yield check
217216
if isinstance(estimator, TransformerMixin):
@@ -980,7 +979,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
980979
X -= X.min()
981980

982981
try:
983-
if isinstance(estimator, ClassifierMixin):
982+
if is_classifier(estimator):
984983
classes = np.unique(y)
985984
estimator.partial_fit(X, y, classes=classes)
986985
else:

0 commit comments

Comments (0)