From 2f69574393d822326a68d3238536d60b50a9e2fe Mon Sep 17 00:00:00 2001
From: trevorstephens
Date: Mon, 8 Jun 2015 21:33:56 -0700
Subject: [PATCH] add sample_weight to RidgeClassifier

---
 doc/whats_new.rst                        |  2 ++
 sklearn/linear_model/ridge.py            | 16 ++++++++-----
 sklearn/linear_model/tests/test_ridge.py | 29 ++++++++++++++++++++++++
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 9add06a83473e..0a3a6dd054255 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -71,6 +71,8 @@ Enhancements
      It is now possible to ignore one or more labels, such as where
      a multiclass problem has a majority class to ignore. By `Joel Nothman`_.
 
+   - Add ``sample_weight`` support to :class:`linear_model.RidgeClassifier`.
+     By `Trevor Stephens`_.
 
 Bug fixes
 .........
diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py
index 77c3e10dbbfd9..a2bd923e9dbfe 100644
--- a/sklearn/linear_model/ridge.py
+++ b/sklearn/linear_model/ridge.py
@@ -572,7 +572,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
             copy_X=copy_X, max_iter=max_iter, tol=tol, solver=solver)
         self.class_weight = class_weight
 
-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit Ridge regression model.
 
         Parameters
@@ -583,20 +583,24 @@ def fit(self, X, y):
         y : array-like, shape = [n_samples]
             Target values
 
+        sample_weight : float or numpy array of shape (n_samples,)
+            Sample weight.
+
         Returns
         -------
         self : returns an instance of self.
         """
+        if sample_weight is None:
+            sample_weight = 1.
+
         self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
         Y = self._label_binarizer.fit_transform(y)
         if not self._label_binarizer.y_type_.startswith('multilabel'):
             y = column_or_1d(y, warn=True)
 
-        if self.class_weight:
-            # get the class weight corresponding to each sample
-            sample_weight = compute_sample_weight(self.class_weight, y)
-        else:
-            sample_weight = None
+        # modify the sample weights with the corresponding class weight
+        sample_weight = (sample_weight *
+                         compute_sample_weight(self.class_weight, y))
 
         super(RidgeClassifier, self).fit(X, Y, sample_weight=sample_weight)
         return self
diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py
index daa8f8bbcc725..e95d22af9c4b2 100644
--- a/sklearn/linear_model/tests/test_ridge.py
+++ b/sklearn/linear_model/tests/test_ridge.py
@@ -487,6 +487,35 @@ def test_class_weights():
     assert_array_almost_equal(clf.intercept_, clfa.intercept_)
 
 
+def test_class_weight_vs_sample_weight():
+    """Check class_weights resemble sample_weights behavior."""
+    for clf in (RidgeClassifier, RidgeClassifierCV):
+
+        # Iris is balanced, so no effect expected for using 'balanced' weights
+        clf1 = clf()
+        clf1.fit(iris.data, iris.target)
+        clf2 = clf(class_weight='balanced')
+        clf2.fit(iris.data, iris.target)
+        assert_almost_equal(clf1.coef_, clf2.coef_)
+
+        # Inflate importance of class 1, check against user-defined weights
+        sample_weight = np.ones(iris.target.shape)
+        sample_weight[iris.target == 1] *= 100
+        class_weight = {0: 1., 1: 100., 2: 1.}
+        clf1 = clf()
+        clf1.fit(iris.data, iris.target, sample_weight)
+        clf2 = clf(class_weight=class_weight)
+        clf2.fit(iris.data, iris.target)
+        assert_almost_equal(clf1.coef_, clf2.coef_)
+
+        # Check that sample_weight and class_weight are multiplicative
+        clf1 = clf()
+        clf1.fit(iris.data, iris.target, sample_weight ** 2)
+        clf2 = clf(class_weight=class_weight)
+        clf2.fit(iris.data, iris.target, sample_weight)
+        assert_almost_equal(clf1.coef_, clf2.coef_)
+
+
 def test_class_weights_cv():
     # Test class weights for cross validated ridge classifier.
     X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
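
Illustrative usage (not part of the patch): a minimal sketch of how the new ``sample_weight`` argument to ``RidgeClassifier.fit`` could be exercised once this change is applied, mirroring the equivalence checked by ``test_class_weight_vs_sample_weight`` above. The iris data and the ``{0: 1., 1: 100., 2: 1.}`` weighting come from that test; the variable names and ``load_iris`` boilerplate are assumptions added for the example.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import RidgeClassifier

iris = load_iris()
X, y = iris.data, iris.target

# Up-weight class 1 via per-sample weights (the argument added by this patch)...
sample_weight = np.ones(y.shape)
sample_weight[y == 1] *= 100
clf_sw = RidgeClassifier().fit(X, y, sample_weight=sample_weight)

# ...or equivalently via class_weight; the coefficients should match because
# the patched fit() multiplies sample_weight by compute_sample_weight(class_weight, y).
clf_cw = RidgeClassifier(class_weight={0: 1., 1: 100., 2: 1.}).fit(X, y)

print(np.allclose(clf_sw.coef_, clf_cw.coef_))  # expected to print True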