8000 Use clone instead of accepting a estimator class · scikit-learn/scikit-learn@22fd57c · GitHub
[go: up one dir, main page]

Skip to content

Commit 22fd57c

Browse files
committed
Use clone instead of accepting a estimator class
1 parent 74d8f39 commit 22fd57c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

sklearn/multi_label/classifier_chain.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111

12-
from ..base import BaseEstimator
12+
from ..base import BaseEstimator, clone
1313

1414

1515
class ClassifierChain(BaseEstimator):
@@ -44,9 +44,11 @@ class ClassifierChain(BaseEstimator):
4444
>>> from sklearn.multi_label import ClassifierChain
4545
>>> from sklearn.svm import LinearSVC
4646
>>> X, Y = make_multilabel_classification(return_indicator=True, random_state=0)
47-
>>> cc = ClassifierChain(base_estimator=LinearSVC)
48-
>>> cc.fit(X, Y)
49-
ClassifierChain(base_estimator=<class 'sklearn.svm.classes.LinearSVC'>)
47+
>>> cc = ClassifierChain(base_estimator=LinearSVC())
48+
>>> cc.fit(X, Y) #doctest: +NORMALIZE_WHITESPACE
49+
ClassifierChain(base_estimator=LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
50+
intercept_scaling=1, loss='l2', max_iter=1000, multi_class='ovr',
51+
penalty='l2', random_state=None, tol=0.0001, verbose=0))
5052
"""
5153

5254
def __init__(self, base_estimator):
@@ -74,7 +76,7 @@ def fit(self, X, Y):
7476
for i in xrange(self.n_labels_):
7577
y = Y[:, i]
7678

77-
clf = self.base_estimator()
79+
clf = clone(self.base_estimator)
7880
clf.fit(X, y)
7981
self.classifiers_.append(clf)
8082

0 commit comments

Comments
 (0)
0