|
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 |
|
12 |
| -from ..base import BaseEstimator |
| 12 | +from ..base import BaseEstimator, clone |
13 | 13 |
|
14 | 14 |
|
15 | 15 | class ClassifierChain(BaseEstimator):
|
@@ -44,9 +44,11 @@ class ClassifierChain(BaseEstimator):
|
44 | 44 | >>> from sklearn.multi_label import ClassifierChain
|
45 | 45 | >>> from sklearn.svm import LinearSVC
|
46 | 46 | >>> X, Y = make_multilabel_classification(return_indicator=True, random_state=0)
|
47 |
| - >>> cc = ClassifierChain(base_estimator=LinearSVC) |
48 |
| - >>> cc.fit(X, Y) |
49 |
| - ClassifierChain(base_estimator=<class 'sklearn.svm.classes.LinearSVC'>) |
| 47 | + >>> cc = ClassifierChain(base_estimator=LinearSVC()) |
| 48 | + >>> cc.fit(X, Y) #doctest: +NORMALIZE_WHITESPACE |
| 49 | + ClassifierChain(base_estimator=LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True, |
| 50 | + intercept_scaling=1, loss='l2', max_iter=1000, multi_class='ovr', |
| 51 | + penalty='l2', random_state=None, tol=0.0001, verbose=0)) |
50 | 52 | """
|
51 | 53 |
|
52 | 54 | def __init__(self, base_estimator):
|
@@ -74,7 +76,7 @@ def fit(self, X, Y):
|
74 | 76 | for i in xrange(self.n_labels_):
|
75 | 77 | y = Y[:, i]
|
76 | 78 |
|
77 |
| - clf = self.base_estimator() |
| 79 | + clf = clone(self.base_estimator) |
78 | 80 | clf.fit(X, y)
|
79 | 81 | self.classifiers_.append(clf)
|
80 | 82 |
|
|
0 commit comments