Add sample_weight support to Dummy Regressor · scikit-learn/scikit-learn@da31344
Commit da31344
Add sample_weight support to Dummy Regressor
1 parent b698d9f

File tree: 4 files changed (+62, -20 lines)

sklearn/dummy.py

Lines changed: 29 additions & 8 deletions

@@ -14,6 +14,7 @@
 from .utils.validation import check_consistent_length
 from .utils import deprecated
 from .utils.random import random_choice_csc
+from .utils.stats import _weighted_percentile
 from .utils.multiclass import class_distribution
 
 
@@ -366,7 +367,7 @@ def y_mean_(self):
             return self.constant_
         raise AttributeError
 
-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit the random regressor.
 
         Parameters
@@ -378,6 +379,9 @@ def fit(self, X, y):
         y : array-like, shape = [n_samples] or [n_samples, n_outputs]
             Target values.
 
+        sample_weight : array-like of shape = [n_samples], optional
+            Sample weights.
+
         Returns
         -------
         self : object
@@ -389,25 +393,43 @@ def fit(self, X, y):
                              "'mean', 'median', 'quantile' or 'constant'"
                              % self.strategy)
 
-        y = check_array(y, accept_sparse='csr', ensure_2d=False)
+        y = check_array(y, ensure_2d=False)
         if len(y) == 0:
             raise ValueError("y must not be empty.")
-        self.output_2d_ = (y.ndim == 2)
 
-        check_consistent_length(X, y)
+        self.output_2d_ = y.ndim == 2
+        if y.ndim == 1:
+            y = np.reshape(y, (-1, 1))
+        self.n_outputs_ = y.shape[1]
+
+        check_consistent_length(X, y, sample_weight)
 
         if self.strategy == "mean":
-            self.constant_ = np.mean(y, axis=0)
+            if sample_weight is None:
+                self.constant_ = np.mean(y, axis=0)
+            else:
+                self.constant_ = np.average(y, axis=0, weights=sample_weight)
 
        elif self.strategy == "median":
-            self.constant_ = np.median(y, axis=0)
+            if sample_weight is None:
+                self.constant_ = np.median(y, axis=0)
+            else:
+                self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
+                                                       percentile=50.)
+                                  for k in range(self.n_outputs_)]
 
         elif self.strategy == "quantile":
             if self.quantile is None or not np.isscalar(self.quantile):
                 raise ValueError("Quantile must be a scalar in the range "
                                  "[0.0, 1.0], but got %s." % self.quantile)
 
-            self.constant_ = np.percentile(y, axis=0, q=self.quantile * 100.0)
+            percentile = self.quantile * 100.0
+            if sample_weight is None:
+                self.constant_ = np.percentile(y, axis=0, q=percentile)
+            else:
+                self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
+                                                       percentile=percentile)
+                                  for k in range(self.n_outputs_)]
 
         elif self.strategy == "constant":
             if self.constant is None:
@@ -426,7 +448,6 @@ def fit(self, X, y):
             self.constant_ = self.constant
 
         self.constant_ = np.reshape(self.constant_, (1, -1))
-        self.n_outputs_ = np.size(self.constant_)  # y.shape[1] is not safe
        return self
 
     def predict(self, X):
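
A minimal usage sketch of the new parameter (illustrative only, not part of the commit; the toy values below are made up):

    import numpy as np
    from sklearn.dummy import DummyRegressor

    y = np.array([1.0, 2.0, 10.0])
    w = np.array([1.0, 1.0, 8.0])   # put most of the weight on the last sample

    # strategy="mean" now averages with the weights instead of np.mean
    reg = DummyRegressor(strategy="mean").fit([[0]] * 3, y, sample_weight=w)
    print(reg.predict([[0]]))       # [ 8.3] == np.average(y, weights=w)

    # strategy="median" now uses the weighted 50th percentile
    reg = DummyRegressor(strategy="median").fit([[0]] * 3, y, sample_weight=w)
    print(reg.predict([[0]]))       # [ 10.], the value carrying most of the weight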

sklearn/ensemble/gradient_boosting.py

Lines changed: 1 addition & 12 deletions

@@ -37,6 +37,7 @@
 from ..base import RegressorMixin
 from ..utils import check_random_state, check_array, check_X_y, column_or_1d
 from ..utils.extmath import logsumexp
+from ..utils.stats import _weighted_percentile
 from ..externals import six
 from ..feature_selection.from_model import _LearntSelectorMixin
 
@@ -50,18 +51,6 @@
 from ._gradient_boosting import _random_sample_mask
 
 
-def _weighted_percentile(array, sample_weight, percentile=50):
-    """Compute the weighted ``percentile`` of ``array`` with ``sample_weight``."""
-    sorted_idx = np.argsort(array)
-
-    # Find index of median prediction for each sample
-    weight_cdf = sample_weight[sorted_idx].cumsum()
-    percentile_or_above = weight_cdf >= (percentile / 100.0) * weight_cdf[-1]
-    percentile_idx = percentile_or_above.argmax()
-
-    return array[sorted_idx[percentile_idx]]
-
-
 class QuantileEstimator(BaseEstimator):
     """An estimator predicting the alpha-quantile of the training targets."""
     def __init__(self, alpha=0.9):
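
The helper itself is unchanged; it only moves from gradient_boosting.py to sklearn/utils/stats.py (diff further down), so the loss functions in this module keep using it through the shared import. Any other in-library caller would make the same one-line switch (sketch, mirroring the import added above):

    # before: module-local _weighted_percentile defined in gradient_boosting.py
    # after:  shared helper in sklearn.utils.stats
    from ..utils.stats import _weighted_percentile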

sklearn/tests/test_dummy.py

Lines changed: 19 additions & 0 deletions

@@ -11,6 +11,7 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_warns_message
+from sklearn.utils.stats import _weighted_percentile
 
 from sklearn.dummy import DummyClassifier, DummyRegressor
 
@@ -572,6 +573,24 @@ def test_most_frequent_strategy_sparse_target():
                                        np.zeros((n_samples, 1))]))
 
 
+def test_dummy_regressor_sample_weight(n_samples=10):
+    random_state = np.random.RandomState(seed=1)
+
+    X = [[0]] * n_samples
+    y = random_state.rand(n_samples)
+    sample_weight = random_state.rand(n_samples)
+
+    est = DummyRegressor(strategy="mean").fit(X, y, sample_weight)
+    assert_equal(est.constant_, np.average(y, weights=sample_weight))
+
+    est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
+    assert_equal(est.constant_, _weighted_percentile(y, sample_weight, 50.))
+
+    est = DummyRegressor(strategy="quantile", quantile=.95).fit(X, y,
+                                                                sample_weight)
+    assert_equal(est.constant_, _weighted_percentile(y, sample_weight, 95.))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

sklearn/utils/stats.py

Lines changed: 13 additions & 0 deletions

@@ -44,3 +44,16 @@ def _rankdata(a, method="average"):
 
 except TypeError as e:
     rankdata = _rankdata
+
+
+def _weighted_percentile(array, sample_weight, percentile=50):
+    """Compute the weighted ``percentile`` of ``array`` with ``sample_weight``."""
+    sorted_idx = np.argsort(array)
+
+    # Find index of median prediction for each sample
+    weight_cdf = sample_weight[sorted_idx].cumsum()
+    percentile_or_above = weight_cdf >= (percentile / 100.0) * weight_cdf[-1]
+    percentile_idx = percentile_or_above.argmax()
+
+    return array[sorted_idx[percentile_idx]]
+
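
To make the helper's semantics concrete, a small worked example (illustrative only; the arrays are made up). The function sorts the values, accumulates the weights, and returns the first value whose cumulative weight reaches percentile/100 of the total weight:

    import numpy as np
    from sklearn.utils.stats import _weighted_percentile

    values = np.array([1.0, 2.0, 3.0])
    weights = np.array([1.0, 2.0, 1.0])

    # Cumulative weights (values already sorted): [1, 3, 4], total 4.
    # 50th percentile: threshold 0.5 * 4 = 2, first reached at value 2.0.
    print(_weighted_percentile(values, weights, percentile=50))  # 2.0

    # 90th percentile: threshold 0.9 * 4 = 3.6, first reached at value 3.0.
    print(_weighted_percentile(values, weights, percentile=90))  # 3.0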

0 commit comments