ENH add `only_non_negative` parameter to `_check_sample_weight` (#20880) · baam25simo/scikit-learn@9c84abd · GitHub
[go: up one dir, main page]

Skip to content

Commit 9c84abd

Browse files
simonandras, glemaitre, and ogrisel
authored
ENH add only_non_negative parameter to _check_sample_weight (scikit-learn#20880)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent ac8a2cd commit 9c84abd

File tree

8 files changed

+53
-22
lines changed

8 files changed

+53
-22
lines changed

doc/whats_new/v1.1.rst

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,20 @@ Changelog
3838
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
3939
where 123456 is the *pull request* number, not the issue number.
4040
41-
42-
:mod:`sklearn.decomposition`
43-
............................
41+
:mod:`sklearn.utils`
42+
....................
43+
44+
- |Enhancement| :func:`utils.validation._check_sample_weight` can perform a
45+
non-negativity check on the sample weights. It can be turned on
46+
using the `only_non_negative` bool parameter.
47+
Estimators that check for non-negative weights are updated:
48+
:func:`linear_model.LinearRegression` (here the previous
49+
error message was misleading),
50+
:func:`ensemble.AdaBoostClassifier`,
51+
:func:`ensemble.AdaBoostRegressor`,
52+
:func:`neighbors.KernelDensity`.
53+
:pr:`20880` by :user:`Guillaume Lemaitre <glemaitre>`
54+
and :user:`András Simon <simonandras>`.
4455

4556

4657
Code and Documentation Contributors

sklearn/ensemble/_weight_boosting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ def fit(self, X, y, sample_weight=None):
123123
y_numeric=is_regressor(self),
124124
)
125125

126-
sample_weight = _check_sample_weight(sample_weight, X, np.float64, copy=True)
126+
sample_weight = _check_sample_weight(
127+
sample_weight, X, np.float64, copy=True, only_non_negative=True
128+
)
127129
sample_weight /= sample_weight.sum()
128-
if np.any(sample_weight < 0):
129-
raise ValueError("sample_weight cannot contain negative weights")
130130

131131
# Check parameters
132132
self._validate_estimator()
@@ -136,7 +136,7 @@ def fit(self, X, y, sample_weight=None):
136136
self.estimator_weights_ = np.zeros(self.n_estimators, dtype=np.float64)
137137
self.estimator_errors_ = np.ones(self.n_estimators, dtype=np.float64)
138138

139-
# Initializion of the random number instance that will be used to
139+
# Initialization of the random number instance that will be used to
140140
# generate a seed at each iteration
141141
random_state = check_random_state(self.random_state)
142142

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,6 @@ def test_adaboost_negative_weight_error(model, X, y):
576576
sample_weight = np.ones_like(y)
577577
sample_weight[-1] = -10
578578

579-
err_msg = "sample_weight cannot contain negative weight"
579+
err_msg = "Negative values in data passed to `sample_weight`"
580580
with pytest.raises(ValueError, match=err_msg):
581581
model.fit(X, y, sample_weight=sample_weight)

sklearn/linear_model/_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,9 @@ def fit(self, X, y, sample_weight=None):
663663
)
664664

665665
if sample_weight is not None:
666-
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
666+
sample_weight = _check_sample_weight(
667+
sample_weight, X, dtype=X.dtype, only_non_negative=True
668+
)
667669

668670
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
669671
X,

sklearn/neighbors/_kde.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def fit(self, X, y=None, sample_weight=None):
191191
X = self._validate_data(X, order="C", dtype=DTYPE)
192192

193193
if sample_weight is not None:
194-
sample_weight = _check_sample_weight(sample_weight, X, DTYPE)
195-
if sample_weight.min() <= 0:
196-
raise ValueError("sample_weight must have positive values")
194+
sample_weight = _check_sample_weight(
195+
sample_weight, X, DTYPE, only_non_negative=True
196+
)
197197

198198
kwargs = self.metric_params
199199
if kwargs is None:

sklearn/neighbors/tests/test_kde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def test_sample_weight_invalid():
209209
data = np.reshape([1.0, 2.0, 3.0], (-1, 1))
210210

211211
sample_weight = [0.1, -0.2, 0.3]
212-
expected_err = "sample_weight must have positive values"
212+
expected_err = "Negative values in data passed to `sample_weight`"
213213
with pytest.raises(ValueError, match=expected_err):
214214
kde.fit(data, sample_weight=sample_weight)
215215

sklearn/utils/tests/test_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
FLOAT_DTYPES,
5353
_get_feature_names,
5454
_check_feature_names_in,
55+
_check_fit_params,
5556
)
56-
from sklearn.utils.validation import _check_fit_params
5757
from sklearn.base import BaseEstimator
5858
import sklearn
5959

@@ -1253,6 +1253,14 @@ def test_check_sample_weight():
12531253
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
12541254
assert sample_weight.dtype == np.float64
12551255

1256+
# check negative weight when only_non_negative=True
1257+
X = np.ones((5, 2))
1258+
sample_weight = np.ones(_num_samples(X))
1259+
sample_weight[-1] = -10
1260+
err_msg = "Negative values in data passed to `sample_weight`"
1261+
with pytest.raises(ValueError, match=err_msg):
1262+
_check_sample_weight(sample_weight, X, only_non_negative=True)
1263+
12561264

12571265
@pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
12581266
def test_allclose_dense_sparse_equals(toarray):

sklearn/utils/validation.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,7 +1492,9 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False):
14921492
return lambdas
14931493

14941494

1495-
def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
1495+
def _check_sample_weight(
1496+
sample_weight, X, dtype=None, copy=False, only_non_negative=False
1497+
):
14961498
"""Validate sample weights.
14971499
14981500
Note that passing sample_weight=None will output an array of ones.
@@ -1503,25 +1505,30 @@ def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
15031505
Parameters
15041506
----------
15051507
sample_weight : {ndarray, Number or None}, shape (n_samples,)
1506-
Input sample weights.
1508+
Input sample weights.
15071509
15081510
X : {ndarray, list, sparse matrix}
15091511
Input data.
15101512
1513+
only_non_negative : bool, default=False,
1514+
Whether or not the weights are expected to be non-negative.
1515+
1516+
.. versionadded:: 1.0
1517+
15111518
dtype : dtype, default=None
1512-
dtype of the validated `sample_weight`.
1513-
If None, and the input `sample_weight` is an array, the dtype of the
1514-
input is preserved; otherwise an array with the default numpy dtype
1515-
is be allocated. If `dtype` is not one of `float32`, `float64`,
1516-
`None`, the output will be of dtype `float64`.
1519+
dtype of the validated `sample_weight`.
1520+
If None, and the input `sample_weight` is an array, the dtype of the
1521+
input is preserved; otherwise an array with the default numpy dtype
1522+
is be allocated. If `dtype` is not one of `float32`, `float64`,
1523+
`None`, the output will be of dtype `float64`.
15171524
15181525
copy : bool, default=False
15191526
If True, a copy of sample_weight will be created.
15201527
15211528
Returns
15221529
-------
15231530
sample_weight : ndarray of shape (n_samples,)
1524-
Validated sample weight. It is guaranteed to be "C" contiguous.
1531+
Validated sample weight. It is guaranteed to be "C" contiguous.
15251532
"""
15261533
n_samples = _num_samples(X)
15271534

@@ -1553,6 +1560,9 @@ def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
15531560
)
15541561
)
15551562

1563+
if only_non_negative:
1564+
check_non_negative(sample_weight, "`sample_weight`")
1565+
15561566
return sample_weight
15571567

15581568

0 commit comments

Comments
 (0)
0