ENH ensure that a warning is raised when sample_weight is not supported · scikit-learn/scikit-learn@70d49de · GitHub

Commit 70d49de
ENH ensure that a warning is raised when sample_weight is not supported
Also: make it possible to fix the random_state that seeds the tie-breaking jitter used for isotonic calibration.
1 parent eee0e67 commit 70d49de
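
For context, a minimal sketch of the behavior this commit introduces. It assumes the API exactly as of this commit (the new `random_state` constructor argument, and a `LinearSVC.fit` that does not yet accept `sample_weight`; later releases changed both), so treat it as illustrative rather than something guaranteed to run against current scikit-learn:

```python
import warnings

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=200, n_features=6, random_state=42)
sample_weight = np.random.RandomState(42).uniform(size=y.shape[0])

# random_state seeds the tiny jitter that breaks ties for the isotonic
# calibrator, making calibration reproducible.
clf = CalibratedClassifierCV(LinearSVC(random_state=42), method='isotonic',
                             random_state=42)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # LinearSVC.fit took no sample_weight at this point in history, so the
    # new code warns and applies the weights to the calibration step only.
    clf.fit(X, y, sample_weight=sample_weight)

print([str(w.message) for w in caught])
```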

File tree

2 files changed: +78 -19 lines

sklearn/calibration.py
sklearn/tests/test_calibration.py


sklearn/calibration.py

Lines changed: 43 additions & 15 deletions
```diff
@@ -9,6 +9,7 @@
 
 from __future__ import division
 import inspect
+import warnings
 
 from math import log
 import numpy as np
@@ -17,10 +18,11 @@
 
 from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
 from .preprocessing import LabelBinarizer
+from .utils import check_random_state
 from .utils import check_X_y, check_array, indexable, column_or_1d
 from .utils.validation import check_is_fitted
 from .isotonic import IsotonicRegression
-from .naive_bayes import GaussianNB
+from .svm import LinearSVC
 from .cross_validation import _check_cv
 from .metrics.classification import _check_binary_probabilistic_predictions
 
@@ -57,6 +59,9 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
         If "prefit" is passed, it is assumed that base_estimator has been
         fitted already and all data is used for calibration.
 
+    random_state : int, RandomState instance or None (default=None)
+        Used to randomly break ties when method is 'isotonic'.
+
     Attributes
     ----------
     classes_ : array, shape (n_classes)
@@ -81,10 +86,12 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
     .. [4] Predicting Good Probabilities with Supervised Learning,
            A. Niculescu-Mizil & R. Caruana, ICML 2005
     """
-    def __init__(self, base_estimator=GaussianNB(), method='sigmoid', cv=3):
+    def __init__(self, base_estimator=None, method='sigmoid', cv=3,
+                 random_state=None):
         self.base_estimator = base_estimator
         self.method = method
         self.cv = cv
+        self.random_state = random_state
 
     def fit(self, X, y, sample_weight=None):
         """Fit the calibrated model
```
```diff
@@ -109,6 +116,7 @@ def fit(self, X, y, sample_weight=None):
         X, y = indexable(X, y)
         lb = LabelBinarizer().fit(y)
         self.classes_ = lb.classes_
+        random_state = check_random_state(self.random_state)
 
         # Check that each cross-validation fold can have at least one
         # example per class
@@ -121,28 +129,43 @@ def fit(self, X, y, sample_weight=None):
                              % (n_folds, n_folds))
 
         self.calibrated_classifiers_ = []
+        if self.base_estimator is None:
+            base_estimator = LinearSVC()
+        else:
+            base_estimator = self.base_estimator
+
         if self.cv == "prefit":
-            calibrated_classifier = _CalibratedClassifier(self.base_estimator,
-                                                          method=self.method)
+            calibrated_classifier = _CalibratedClassifier(
+                base_estimator, method=self.method, random_state=random_state)
             if sample_weight is not None:
                 calibrated_classifier.fit(X, y, sample_weight)
             else:
                 calibrated_classifier.fit(X, y)
             self.calibrated_classifiers_.append(calibrated_classifier)
         else:
             cv = _check_cv(self.cv, X, y, classifier=True)
+            arg_names = inspect.getargspec(base_estimator.fit)[0]
+            estimator_name = type(base_estimator).__name__
+            if (sample_weight is not None
+                    and "sample_weight" not in arg_names):
+                warnings.warn("%s does not support sample_weight. Sample"
+                              " weights are only used for the calibration"
+                              " itself." % estimator_name)
+                base_estimator_sample_weight = None
+            else:
+                base_estimator_sample_weight = sample_weight
             for train, test in cv:
-                this_estimator = clone(self.base_estimator)
-                if sample_weight is not None and \
-                        "sample_weight" in inspect.getargspec(
-                            this_estimator.fit)[0]:
-                    this_estimator.fit(X[train], y[train],
-                                       sample_weight[train])
+                this_estimator = clone(base_estimator)
+                if base_estimator_sample_weight is not None:
+                    this_estimator.fit(
+                        X[train], y[train],
+                        sample_weight=base_estimator_sample_weight[train])
                 else:
                     this_estimator.fit(X[train], y[train])
 
-                calibrated_classifier = \
-                    _CalibratedClassifier(this_estimator, method=self.method)
+                calibrated_classifier = _CalibratedClassifier(
+                    this_estimator, method=self.method,
+                    random_state=random_state)
                 if sample_weight is not None:
                     calibrated_classifier.fit(X[test], y[test],
                                               sample_weight[test])
```
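
The new warning relies on plain introspection of the estimator's `fit` signature. A standalone sketch of the same check; it swaps in `inspect.signature`, since the `inspect.getargspec` call used by the patch is deprecated (and removed in Python 3.11):

```python
import inspect

from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC


def fit_accepts_sample_weight(estimator):
    """Does estimator.fit advertise a sample_weight parameter?"""
    # The patch reads inspect.getargspec(base_estimator.fit)[0]; on modern
    # Pythons, inspect.signature is the equivalent introspection tool.
    return "sample_weight" in inspect.signature(estimator.fit).parameters


print(fit_accepts_sample_weight(MultinomialNB()))  # True
print(fit_accepts_sample_weight(LinearSVC()))      # False at this commit;
                                                   # True in later releases
```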
```diff
@@ -219,6 +242,9 @@ class _CalibratedClassifier(object):
         corresponds to Platt's method or 'isotonic' which is a
         non-parametric approach based on isotonic regression.
 
+    random_state : int, RandomState instance or None (default=None)
+        Used to randomly break ties when method is 'isotonic'.
+
     References
     ----------
     .. [1] Obtaining calibrated probability estimates from decision trees
@@ -233,9 +259,11 @@ class _CalibratedClassifier(object):
     .. [4] Predicting Good Probabilities with Supervised Learning,
            A. Niculescu-Mizil & R. Caruana, ICML 2005
     """
-    def __init__(self, base_estimator, method='sigmoid'):
+    def __init__(self, base_estimator, method='sigmoid',
+                 random_state=None):
         self.base_estimator = base_estimator
         self.method = method
+        self.random_state = random_state
 
     def _preproc(self, X):
         n_classes = len(self.classes_)
@@ -289,8 +317,8 @@ def fit(self, X, y, sample_weight=None):
                 # have different outputs. Since this is not untypical
                 # when calibrating, we add some small random jitter to
                 # the inputs.
-                this_df = \
-                    this_df + np.random.normal(0, 1e-10, this_df.shape[0])
+                jitter = self.random_state.normal(0, 1e-10, this_df.shape[0])
+                this_df = this_df + jitter
             elif self.method == 'sigmoid':
                 calibrator = _SigmoidCalibration()
             else:
```
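
For background on the last hunk: `IsotonicRegression` needs a strict ordering of its inputs, and decision functions routinely produce exact ties, so a vanishingly small jitter makes the ordering unique; seeding it through `random_state` makes calibration reproducible. A toy illustration with made-up scores:

```python
import numpy as np
from sklearn.utils import check_random_state

scores = np.array([0.1, 0.5, 0.5, 0.5, 0.9])  # tied decision values

rng = check_random_state(42)
jitter = rng.normal(0, 1e-10, scores.shape[0])
jittered = scores + jitter

# Ties are now broken deterministically for a fixed seed, and the
# perturbation is far below any meaningful score resolution.
print(np.argsort(jittered))
```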

sklearn/tests/test_calibration.py

Lines changed: 35 additions & 4 deletions
```diff
@@ -8,7 +8,8 @@
                                    assert_greater, assert_almost_equal,
                                    assert_greater_equal,
                                    assert_array_equal,
-                                   assert_raises)
+                                   assert_raises,
+                                   assert_warns_message)
 from sklearn.datasets import make_classification, make_blobs
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
@@ -35,8 +36,7 @@ def test_calibration():
     X_test, y_test = X[n_samples:], y[n_samples:]
 
     # Naive-Bayes
-    clf = MultinomialNB()
-    clf.fit(X_train, y_train, sw_train)
+    clf = MultinomialNB().fit(X_train, y_train, sample_weight=sw_train)
     prob_pos_clf = clf.predict_proba(X_test)[:, 1]
 
     pc_clf = CalibratedClassifierCV(clf, cv=y.size + 1)
@@ -86,7 +86,7 @@ def test_calibration():
 
     # check that calibration can also deal with regressors that have
     # a decision_function
-    clf_base_regressor = CalibratedClassifierCV(Ridge(), method="sigmoid")
+    clf_base_regressor = CalibratedClassifierCV(Ridge())
     clf_base_regressor.fit(X_train, y_train)
     clf_base_regressor.predict(X_test)
 
@@ -102,6 +102,37 @@ def test_calibration():
     assert_raises(RuntimeError, clf_base_regressor.fit, X_train, y_train)
 
 
+def test_sample_weight_warning():
+    n_samples = 100
+    X, y = make_classification(n_samples=2 * n_samples, n_features=6,
+                               random_state=42)
+
+    sample_weight = np.random.RandomState(seed=42).uniform(size=len(y))
+    X_train, y_train, sw_train = \
+        X[:n_samples], y[:n_samples], sample_weight[:n_samples]
+    X_test = X[n_samples:]
+
+    for method in ['sigmoid', 'isotonic']:
+        base_estimator = LinearSVC(random_state=42)
+        calibrated_clf = CalibratedClassifierCV(base_estimator, method=method,
+                                                random_state=42)
+        # LinearSVC does not currently support sample weights but they
+        # can still be used for the calibration step (with a warning)
+        msg = "LinearSVC does not support sample_weight."
+        assert_warns_message(
+            UserWarning, msg,
+            calibrated_clf.fit, X_train, y_train, sample_weight=sw_train)
+        probs_with_sw = calibrated_clf.predict_proba(X_test)
+
+        # As the weights are used for the calibration, they should still
+        # yield different predictions
+        calibrated_clf.fit(X_train, y_train)
+        probs_without_sw = calibrated_clf.predict_proba(X_test)
+
+        diff = np.linalg.norm(probs_with_sw - probs_without_sw)
+        assert_greater(diff, 0.1)
+
+
 def test_calibration_multiclass():
     """Test calibration for multiclass """
     # test multi-class setting with classifier that implements
```
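
`assert_warns_message` (from `sklearn.utils.testing`, since removed from the library) asserts that the call emits a warning of the given class whose text contains the given message. A rough, self-contained equivalent, offered only as a sketch of the semantics rather than the actual implementation:

```python
import warnings


def check_warns_message(category, message, func, *args, **kwargs):
    """Call func and assert it emits a `category` warning containing `message`."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = func(*args, **kwargs)
    if not any(issubclass(w.category, category) and message in str(w.message)
               for w in caught):
        raise AssertionError("expected %s containing %r"
                             % (category.__name__, message))
    return result
```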
