FIX ignore null weight when computing estimator error in AdaBoostRegressor · scikit-learn/scikit-learn@1888a96 · GitHub
Commit 1888a96

glemaitre authored and adrinjalali committed
FIX ignore null weight when computing estimator error in AdaBoostRegressor (#14294)
* FIX normalize with max of samples with non-null weights in AdaBoostRegressor
* iter
* update PR number
* PEP8
* iter
* iter
* ignore line coverage
* FIX use _check_sample_weight to validate sample_weight
* address jeremie comments
* fix inplace/mask copy operation
* add default value for wrapper
* PEP8
* Apply suggestions from code review
  Co-Authored-By: jeremiedbb <34657725+jeremiedbb@users.noreply.github.com>
* apply jeremie comments
* increase coverage
* PEP8
* address adrin comments
* change import for mocking
* fix
* PEP8
* address comment adrin
1 parent 846e6a3 commit 1888a96

File tree: 6 files changed (+130, −67 lines)

doc/whats_new/v0.22.rst (4 additions, 0 deletions)

@@ -255,6 +255,10 @@ Changelog
   now raise consistent error messages.
   :pr:`15084` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- |Fix| :class:`ensemble.AdaBoostRegressor` where the loss should be normalized
+  by the max of the samples with non-null weights only.
+  :pr:`14294` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.feature_extraction`
 .................................
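In user terms, the entry above means a sample given a null weight can no longer inflate the loss normalization. A minimal sketch of the user-visible behavior, assuming the fixed estimator (the dataset and weights here are illustrative, not from the PR):

```python
import numpy as np
from sklearn.ensemble import AdaBoostRegressor
from sklearn.linear_model import LinearRegression

X = np.linspace(0, 100, num=1000).reshape(-1, 1)
y = 0.8 * X.ravel() + 0.2

# a wild outlier that receives a null weight
X[-1] *= 10
y[-1] = 10000
w = np.ones_like(y)
w[-1] = 0

reg = AdaBoostRegressor(base_estimator=LinearRegression(),
                        n_estimators=1, random_state=0)
reg.fit(X, y, sample_weight=w)
# with the fix, scoring on the clean points matches a fit that simply
# drops the outlier (see the new test further down this page)
print(reg.score(X[:-1], y[:-1]))
```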

sklearn/ensemble/_weight_boosting.py (21 additions, 30 deletions)

@@ -38,6 +38,7 @@
 from ..utils.extmath import stable_cumsum
 from ..metrics import accuracy_score, r2_score
 from ..utils.validation import check_is_fitted
+from ..utils.validation import _check_sample_weight
 from ..utils.validation import has_fit_parameter
 from ..utils.validation import _num_samples

@@ -117,20 +118,10 @@ def fit(self, X, y, sample_weight=None):
 
         X, y = self._validate_data(X, y)
 
-        if sample_weight is None:
-            # Initialize weights to 1 / n_samples
-            sample_weight = np.empty(_num_samples(X), dtype=np.float64)
-            sample_weight[:] = 1. / _num_samples(X)
-        else:
-            sample_weight = check_array(sample_weight, ensure_2d=False)
-            # Normalize existing weights
-            sample_weight = sample_weight / sample_weight.sum(dtype=np.float64)
-
-        # Check that the sample weights sum is positive
-        if sample_weight.sum() <= 0:
-            raise ValueError(
-                "Attempting to fit with a non-positive "
-                "weighted number of samples.")
+        sample_weight = _check_sample_weight(sample_weight, X, np.float64)
+        sample_weight /= sample_weight.sum()
+        if np.any(sample_weight < 0):
+            raise ValueError("sample_weight cannot contain negative weights")
 
         # Check parameters
         self._validate_estimator()
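For orientation, here is a standalone sketch of what the consolidated path does. `validate_boosting_weights` is a hypothetical stand-in that restates `_check_sample_weight`'s None-handling in plain NumPy rather than calling the private helper:

```python
import numpy as np

def validate_boosting_weights(sample_weight, n_samples):
    # hypothetical stand-in: _check_sample_weight returns uniform ones
    # when sample_weight is None
    if sample_weight is None:
        sample_weight = np.ones(n_samples, dtype=np.float64)
    else:
        sample_weight = np.asarray(sample_weight, dtype=np.float64)
    # normalize to a probability distribution, as fit now does
    sample_weight = sample_weight / sample_weight.sum()
    # the explicit negative check replaces the old "non-positive sum" error
    if np.any(sample_weight < 0):
        raise ValueError("sample_weight cannot contain negative weights")
    return sample_weight

print(validate_boosting_weights(None, 4))          # [0.25 0.25 0.25 0.25]
print(validate_boosting_weights([2, 2, 0, 0], 4))  # [0.5 0.5 0.  0. ]
```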
@@ -1029,13 +1020,10 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         estimator = self._make_estimator(random_state=random_state)
 
         # Weighted sampling of the training set with replacement
-        # For NumPy >= 1.7.0 use np.random.choice
-        cdf = stable_cumsum(sample_weight)
-        cdf /= cdf[-1]
-        uniform_samples = random_state.random_sample(_num_samples(X))
-        bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
-        # searchsorted returns a scalar
-        bootstrap_idx = np.array(bootstrap_idx, copy=False)
+        bootstrap_idx = random_state.choice(
+            np.arange(_num_samples(X)), size=_num_samples(X), replace=True,
+            p=sample_weight
+        )
 
         # Fit on the bootstrapped sample and obtain a prediction
         # for all samples in the training set
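The replaced block was a hand-rolled inverse-CDF sampler kept around for pre-1.7 NumPy; `RandomState.choice` draws from the same weighted distribution directly. A sketch of the equivalence (seeds and weights are illustrative, not from the PR):

```python
import numpy as np

sample_weight = np.array([0.1, 0.2, 0.3, 0.4])
n = sample_weight.shape[0]

# old approach: invert the cumulative distribution by hand
rng = np.random.RandomState(0)
cdf = np.cumsum(sample_weight)
cdf /= cdf[-1]
idx_old = cdf.searchsorted(rng.random_sample(n), side='right')

# new approach: RandomState.choice needs weights summing to 1,
# which the normalization in fit now guarantees
rng = np.random.RandomState(0)
idx_new = rng.choice(np.arange(n), size=n, replace=True, p=sample_weight)

# both index arrays are draws from the same weighted distribution
print(idx_old, idx_new)
```

A useful side effect for the fix: `choice` with `p=sample_weight` can never select an index whose weight is exactly zero.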
@@ -1045,18 +1033,21 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         y_predict = estimator.predict(X)
 
         error_vect = np.abs(y_predict - y)
-        error_max = error_vect.max()
+        sample_mask = sample_weight > 0
+        masked_sample_weight = sample_weight[sample_mask]
+        masked_error_vector = error_vect[sample_mask]
 
-        if error_max != 0.:
-            error_vect /= error_max
+        error_max = masked_error_vector.max()
+        if error_max != 0:
+            masked_error_vector /= error_max
 
         if self.loss == 'square':
-            error_vect **= 2
+            masked_error_vector **= 2
         elif self.loss == 'exponential':
-            error_vect = 1. - np.exp(- error_vect)
+            masked_error_vector = 1. - np.exp(-masked_error_vector)
 
         # Calculate the average loss
-        estimator_error = (sample_weight * error_vect).sum()
+        estimator_error = (masked_sample_weight * masked_error_vector).sum()
 
         if estimator_error <= 0:
             # Stop if fit is perfect
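This hunk is the heart of the fix: previously a sample with zero weight could still set `error_max`, deflating every other normalized error and hence the estimator error. A small numeric illustration (values chosen for clarity, not taken from the PR):

```python
import numpy as np

error_vect = np.array([1.0, 2.0, 1.5, 1000.0])  # last sample is an outlier
sample_weight = np.array([1/3, 1/3, 1/3, 0.0])  # ...with a null weight

# before: normalization uses the outlier, so the average loss collapses
before = (sample_weight * (error_vect / error_vect.max())).sum()

# after: the null-weight sample is excluded from the normalization
mask = sample_weight > 0
after = (sample_weight[mask]
         * (error_vect[mask] / error_vect[mask].max())).sum()

print(before, after)  # ~0.0015 vs 0.75
```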
@@ -1074,9 +1065,9 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         estimator_weight = self.learning_rate * np.log(1. / beta)
 
         if not iboost == self.n_estimators - 1:
-            sample_weight *= np.power(
-                beta,
-                (1. - error_vect) * self.learning_rate)
+            sample_weight[sample_mask] *= np.power(
+                beta, (1. - masked_error_vector) * self.learning_rate
+            )
 
         return sample_weight, estimator_weight, estimator_error
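The masked assignment is needed for shape agreement: `masked_error_vector` now has one entry per positively weighted sample, so the update must index `sample_weight` the same way. A toy sketch:

```python
import numpy as np

sample_weight = np.array([0.5, 0.0, 0.5])
sample_mask = sample_weight > 0
masked_error_vector = np.array([0.2, 0.8])  # one entry per unmasked sample
beta, learning_rate = 0.5, 1.0

sample_weight[sample_mask] *= np.power(
    beta, (1. - masked_error_vector) * learning_rate
)
print(sample_weight)  # zero-weight entries stay exactly 0.0
```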

sklearn/ensemble/tests/test_gradient_boosting.py (2 additions, 15 deletions)

@@ -24,6 +24,7 @@
 from sklearn.metrics import mean_squared_error
 from sklearn.model_selection import train_test_split
 from sklearn.utils import check_random_state, tosequence
+from sklearn.utils._mocking import NoSampleWeightWrapper
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal

@@ -1292,20 +1293,6 @@ def test_early_stopping_stratified():
     gbc.fit(X, y)
 
 
-class _NoSampleWeightWrapper(BaseEstimator):
-    def __init__(self, est):
-        self.est = est
-
-    def fit(self, X, y):
-        self.est.fit(X, y)
-
-    def predict(self, X):
-        return self.est.predict(X)
-
-    def predict_proba(self, X):
-        return self.est.predict_proba(X)
-
-
 def _make_multiclass():
     return make_classification(n_classes=3, n_clusters_per_class=1)

@@ -1330,7 +1317,7 @@ def test_gradient_boosting_with_init(gb, dataset_maker, init_estimator):
     gb(init=init_est).fit(X, y, sample_weight=sample_weight)
 
     # init does not support sample weights
-    init_est = _NoSampleWeightWrapper(init_estimator())
+    init_est = NoSampleWeightWrapper(init_estimator())
     gb(init=init_est).fit(X, y)  # ok no sample weights
     with pytest.raises(ValueError,
                        match="estimator.*does not support sample weights"):

sklearn/ensemble/tests/test_weight_boosting.py (77 additions, 20 deletions)

@@ -3,24 +3,29 @@
 import numpy as np
 import pytest
 
+from scipy.sparse import csc_matrix
+from scipy.sparse import csr_matrix
+from scipy.sparse import coo_matrix
+from scipy.sparse import dok_matrix
+from scipy.sparse import lil_matrix
+
 from sklearn.utils.testing import assert_array_equal, assert_array_less
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_raises, assert_raises_regexp
 
 from sklearn.base import BaseEstimator
+from sklearn.base import clone
+from sklearn.dummy import DummyClassifier, DummyRegressor
+from sklearn.linear_model import LinearRegression
 from sklearn.model_selection import train_test_split
 from sklearn.model_selection import GridSearchCV
 from sklearn.ensemble import AdaBoostClassifier
 from sklearn.ensemble import AdaBoostRegressor
 from sklearn.ensemble._weight_boosting import _samme_proba
-from scipy.sparse import csc_matrix
-from scipy.sparse import csr_matrix
-from scipy.sparse import coo_matrix
-from scipy.sparse import dok_matrix
-from scipy.sparse import lil_matrix
 from sklearn.svm import SVC, SVR
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.utils import shuffle
+from sklearn.utils._mocking import NoSampleWeightWrapper
 from sklearn import datasets

@@ -137,9 +142,10 @@ def test_iris():
         np.abs(clf_samme.predict_proba(iris.data) - prob_samme))
 
 
-def test_boston():
+@pytest.mark.parametrize('loss', ['linear', 'square', 'exponential'])
+def test_boston(loss):
     # Check consistency on dataset boston house prices.
-    reg = AdaBoostRegressor(random_state=0)
+    reg = AdaBoostRegressor(loss=loss, random_state=0)
     reg.fit(boston.data, boston.target)
     score = reg.score(boston.data, boston.target)
     assert score > 0.85
@@ -304,16 +310,6 @@ def test_base_estimator():
                   clf.fit, X_fail, y_fail)
 
 
-def test_sample_weight_missing():
-    from sklearn.cluster import KMeans
-
-    clf = AdaBoostClassifier(KMeans(), algorithm="SAMME")
-    assert_raises(ValueError, clf.fit, X, y_regr)
-
-    clf = AdaBoostRegressor(KMeans())
-    assert_raises(ValueError, clf.fit, X, y_regr)
-
-
 def test_sparse_classification():
     # Check classification with sparse input.

@@ -486,9 +482,6 @@ def test_multidimensional_X():
     Check that the AdaBoost estimators can work with n-dimensional
     data matrix
     """
-
-    from sklearn.dummy import DummyClassifier, DummyRegressor
-
     rng = np.random.RandomState(0)
 
     X = rng.randn(50, 3, 3)
@@ -505,6 +498,56 @@ def test_multidimensional_X():
     boost.predict(X)
 
 
+@pytest.mark.parametrize("algorithm", ['SAMME', 'SAMME.R'])
+def test_adaboostclassifier_without_sample_weight(algorithm):
+    X, y = iris.data, iris.target
+    base_estimator = NoSampleWeightWrapper(DummyClassifier())
+    clf = AdaBoostClassifier(
+        base_estimator=base_estimator, algorithm=algorithm
+    )
+    err_msg = ("{} doesn't support sample_weight"
+               .format(base_estimator.__class__.__name__))
+    with pytest.raises(ValueError, match=err_msg):
+        clf.fit(X, y)
+
+
+def test_adaboostregressor_sample_weight():
+    # check that giving weight will have an influence on the error computed
+    # for a weak learner
+    rng = np.random.RandomState(42)
+    X = np.linspace(0, 100, num=1000)
+    y = (.8 * X + 0.2) + (rng.rand(X.shape[0]) * 0.0001)
+    X = X.reshape(-1, 1)
+
+    # add an arbitrary outlier
+    X[-1] *= 10
+    y[-1] = 10000
+
+    # random_state=0 ensures that the underlying bootstrap will use the outlier
+    regr_no_outlier = AdaBoostRegressor(
+        base_estimator=LinearRegression(), n_estimators=1, random_state=0
+    )
+    regr_with_weight = clone(regr_no_outlier)
+    regr_with_outlier = clone(regr_no_outlier)
+
+    # fit 3 models:
+    # - a model containing the outlier
+    # - a model without the outlier
+    # - a model containing the outlier but with a null sample-weight
+    regr_with_outlier.fit(X, y)
+    regr_no_outlier.fit(X[:-1], y[:-1])
+    sample_weight = np.ones_like(y)
+    sample_weight[-1] = 0
+    regr_with_weight.fit(X, y, sample_weight=sample_weight)
+
+    score_with_outlier = regr_with_outlier.score(X[:-1], y[:-1])
+    score_no_outlier = regr_no_outlier.score(X[:-1], y[:-1])
+    score_with_weight = regr_with_weight.score(X[:-1], y[:-1])
+
+    assert score_with_outlier < score_no_outlier
+    assert score_with_outlier < score_with_weight
+    assert score_no_outlier == pytest.approx(score_with_weight)
+
 @pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
 def test_adaboost_consistent_predict(algorithm):
     # check that predict_proba and predict give consistent results
@@ -520,3 +563,17 @@ def test_adaboost_consistent_predict(algorithm):
         np.argmax(model.predict_proba(X_test), axis=1),
         model.predict(X_test)
     )
+
+
+@pytest.mark.parametrize(
+    'model, X, y',
+    [(AdaBoostClassifier(), iris.data, iris.target),
+     (AdaBoostRegressor(), boston.data, boston.target)]
+)
+def test_adaboost_negative_weight_error(model, X, y):
+    sample_weight = np.ones_like(y)
+    sample_weight[-1] = -10
+
+    err_msg = "sample_weight cannot contain negative weight"
+    with pytest.raises(ValueError, match=err_msg):
+        model.fit(X, y, sample_weight=sample_weight)
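The negative-weight test above corresponds to a user-facing error that, as of this commit, is raised before any boosting happens. A quick reproduction sketch:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier

X, y = load_iris(return_X_y=True)
sample_weight = np.ones_like(y, dtype=np.float64)
sample_weight[-1] = -10

clf = AdaBoostClassifier()
clf.fit(X, y, sample_weight=sample_weight)
# ValueError: sample_weight cannot contain negative weights
```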

sklearn/utils/_mocking.py (24 additions, 0 deletions)

@@ -135,3 +135,27 @@ def score(self, X=None, Y=None):
 
     def _more_tags(self):
         return {'_skip_test': True, 'X_types': ['1dlabel']}
+
+
+class NoSampleWeightWrapper(BaseEstimator):
+    """Wrap estimator which will not expose `sample_weight`.
+
+    Parameters
+    ----------
+    est : estimator, default=None
+        The estimator to wrap.
+    """
+    def __init__(self, est=None):
+        self.est = est
+
+    def fit(self, X, y):
+        return self.est.fit(X, y)
+
+    def predict(self, X):
+        return self.est.predict(X)
+
+    def predict_proba(self, X):
+        return self.est.predict_proba(X)
+
+    def _more_tags(self):
+        return {'_skip_test': True}  # pragma: no cover
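A short usage sketch of the relocated mock, assuming the signature-inspection path AdaBoost uses: support for weighting is detected via `has_fit_parameter`, and the wrapper's `fit` deliberately takes only `X` and `y`:

```python
from sklearn.dummy import DummyClassifier
from sklearn.utils._mocking import NoSampleWeightWrapper
from sklearn.utils.validation import has_fit_parameter

wrapped = NoSampleWeightWrapper(DummyClassifier())
assert has_fit_parameter(DummyClassifier(), "sample_weight")
assert not has_fit_parameter(wrapped, "sample_weight")
# so AdaBoostClassifier(base_estimator=wrapped).fit(...) raises ValueError
```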

sklearn/utils/validation.py (2 additions, 2 deletions)

@@ -1062,8 +1062,8 @@ def _check_sample_weight(sample_weight, X, dtype=None):
     if dtype is None:
         dtype = [np.float64, np.float32]
     sample_weight = check_array(
-        sample_weight, accept_sparse=False,
-        ensure_2d=False, dtype=dtype, order="C"
+        sample_weight, accept_sparse=False, ensure_2d=False, dtype=dtype,
+        order="C"
     )
     if sample_weight.ndim != 1:
         raise ValueError("Sample weights must be 1D array or scalar")
