-
-
Notifications
You must be signed in to change notification settings - Fork 26k
MAINT Common sample_weight validation #14307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9800b4b
95df187
c889db3
bd52cfc
9e108a4
4ff292d
bb64a9b
29e4ff6
4d7bb15
1c0f6a7
59abc05
84b0ac0
4ea0694
cfc7a97
908bbfc
c6280b6
2b84f90
d81fec1
b2b1773
22f9275
c28226a
ed2dc69
0ce20bc
380d9eb
3fa5f73
08e204f
44d99c1
3fc9d1a
22e1070
561bb6a
71ecf65
e244ad5
13f9dec
9cccaf6
fb22cfc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
6D40
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,8 @@ | |
from ..base import BaseEstimator, RegressorMixin | ||
from .base import LinearModel | ||
from ..utils import check_X_y | ||
from ..utils import check_consistent_length | ||
from ..utils import axis0_safe_slice | ||
from ..utils.validation import _check_sample_weight | ||
from ..utils.extmath import safe_sparse_dot | ||
from ..utils.optimize import _check_optimize_result | ||
|
||
|
@@ -255,11 +255,8 @@ def fit(self, X, y, sample_weight=None): | |
X, y = check_X_y( | ||
X, y, copy=False, accept_sparse=['csr'], y_numeric=True, | ||
dtype=[np.float64, np.float32]) | ||
if sample_weight is not None: | ||
sample_weight = np.array(sample_weight) | ||
check_consistent_length(y, sample_weight) | ||
else: | ||
sample_weight = np.ones_like(y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm pretty sure this didn't produce sample weights with the expected dtype when |
||
|
||
sample_weight = _check_sample_weight(sample_weight, X) | ||
|
||
if self.epsilon < 1.0: | ||
raise ValueError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from sklearn.utils.testing import SkipTest | ||
from sklearn.utils.testing import assert_array_equal | ||
from sklearn.utils.testing import assert_allclose_dense_sparse | ||
from sklearn.utils.testing import assert_allclose | ||
from sklearn.utils import as_float_array, check_array, check_symmetric | ||
from sklearn.utils import check_X_y | ||
from sklearn.utils import deprecated | ||
|
@@ -39,7 +40,8 @@ | |
check_memory, | ||
check_non_negative, | ||
_num_samples, | ||
check_scalar) | ||
check_scalar, | ||
_check_sample_weight) | ||
import sklearn | ||
|
||
from sklearn.exceptions import NotFittedError | ||
|
@@ -853,3 +855,40 @@ def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val, | |
min_val=min_val, max_val=max_val) | ||
assert str(raised_error.value) == str(err_msg) | ||
assert type(raised_error.value) == type(err_msg) | ||
|
||
|
||
def test_check_sample_weight():
    """Check the input validation performed by ``_check_sample_weight``.

    Covers: memory layout normalization, the ``None`` and scalar shortcuts,
    dimensionality / length validation errors, and dtype handling
    (float32 preserved, integer dtypes promoted to float64).
    """
    # check array order: a non-contiguous view must come back C-contiguous
    sample_weight = np.ones(10)[::2]
    assert not sample_weight.flags["C_CONTIGUOUS"]
    sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
    assert sample_weight.flags["C_CONTIGUOUS"]

    # check None input: expands to uniform unit weights, one per sample
    sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
    assert_allclose(sample_weight, np.ones(5))

    # check numbers input: a scalar is broadcast to every sample
    sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
    assert_allclose(sample_weight, 2 * np.ones(5))

    # check wrong number of dimensions
    with pytest.raises(ValueError,
                       match="Sample weights must be 1D array or scalar"):
        _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))

    # check incorrect n_samples
    msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
    with pytest.raises(ValueError, match=msg):
        _check_sample_weight(np.ones(4), X=np.ones((2, 2)))

    # float32 dtype is preserved
    X = np.ones((5, 2))
    sample_weight = np.ones(5, dtype=np.float32)
    sample_weight = _check_sample_weight(sample_weight, X)
    assert sample_weight.dtype == np.float32

    # int dtype will be converted to float64 instead
    # (builtin `int` instead of the deprecated/removed `np.int` alias)
    X = np.ones((5, 2), dtype=int)
    sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
    assert sample_weight.dtype == np.float64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a test for the shape difference error? (it's implicitly tested in kmeans) |
Uh oh!
There was an error while loading. Please reload this page.