8000 TST: added test that sample_weight can be a list (#8261) · scikit-learn/scikit-learn@c5bcdde · GitHub
[go: up one dir, main page]

Skip to content

Commit c5bcdde

Browse files
dalmiaj authored and jnothman committed
TST: added test that sample_weight can be a list (#8261)
1 parent 53f8082 commit c5bcdde

File tree

5 files changed

+33
-5
lines changed

5 files changed

+33
-5
lines changed

sklearn/calibration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
2020
from .preprocessing import label_binarize, LabelBinarizer
2121
from .utils import check_X_y, check_array, indexable, column_or_1d
22-
from .utils.validation import check_is_fitted
22+
from .utils.validation import check_is_fitted, check_consistent_length
2323
from .utils.fixes import signature
2424
from .isotonic import IsotonicRegression
2525
from .svm import LinearSVC
@@ -167,6 +167,9 @@ def fit(self, X, y, sample_weight=None):
167167
" itself." % estimator_name)
168168
base_estimator_sample_weight = None
169169
else:
170+
if sample_weight is not None:
171+
sample_weight = check_array(sample_weight, ensure_2d=False)
172+
check_consistent_length(y, sample_weight)
170173
base_estimator_sample_weight = sample_weight
171174
for train, test in cv.split(X, y):
172175
this_estimator = clone(base_estimator)

sklearn/ensemble/bagging.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..utils import check_random_state, check_X_y, check_array, column_or_1d
2121
from ..utils.random import sample_without_replacement
2222
from ..utils.validation import has_fit_parameter, check_is_fitted
23-
from ..utils import indices_to_mask
23+
from ..utils import indices_to_mask, check_consistent_length
2424
from ..utils.fixes import bincount
2525
from ..utils.metaestimators import if_delegate_has_method
2626
from ..utils.multiclass import check_classification_targets
@@ -82,8 +82,8 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
8282

8383
for i in range(n_estimators):
8484
if verbose > 1:
85-
print("Building estimator %d of %d for this parallel run (total %d)..." %
86-
(i + 1, n_estimators, total_n_estimators))
85+
print("Building estimator %d of %d for this parallel run "
86+
"(total %d)..." % (i + 1, n_estimators, total_n_estimators))
8787

8888
random_state = np.random.RandomState(seeds[i])
8989
estimator = ensemble._make_estimator(append=False,
@@ -282,6 +282,9 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
282282

283283
# Convert data
284284
X, y = check_X_y(X, y, ['csr', 'csc'])
285+
if sample_weight is not None:
286+
sample_weight = check_array(sample_weight, ensure_2d=False)
287+
check_consistent_length(y, sample_weight)
285288

286289
# Remap output
287290
n_samples, self.n_features_ = X.shape

sklearn/linear_model/logistic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,9 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
888888
y_test = y[test]
889889

890890
if sample_weight is not None:
891+
sample_weight = check_array(sample_weight, ensure_2d=False)
892+
check_consistent_length(y, sample_weight)
893+
891894
sample_weight = sample_weight[train]
892895

893896
coefs, Cs, n_iter = logistic_regression_path(

sklearn/naive_bayes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .preprocessing import binarize
2626
from .preprocessing import LabelBinarizer
2727
from .preprocessing import label_binarize
28-
from .utils import check_X_y, check_array
28+
from .utils import check_X_y, check_array, check_consistent_length
2929
from .utils.extmath import safe_sparse_dot, logsumexp
3030
from .utils.multiclass import _check_partial_fit_first_call
3131
from .utils.fixes import in1d
@@ -333,6 +333,9 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
333333
Returns self.
334334
"""
335335
X, y = check_X_y(X, y)
336+
if sample_weight is not None:
337+
sample_weight = check_array(sample_weight, ensure_2d=False)
338+
check_consistent_length(y, sample_weight)
336339

337340
# If the ratio of data variance between dimensions is too small, it
338341
# will cause numerical errors. To address this, we artificially

sklearn/utils/estimator_checks.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def _yield_non_meta_checks(name, Estimator):
7272
yield check_fit_score_takes_y
7373
yield check_dtype_object
7474
yield check_sample_weights_pandas_series
75+
yield check_sample_weights_list
7576
yield check_estimators_fit_returns_self
7677

7778
# Check that all estimator yield informative messages when
@@ -396,6 +397,21 @@ def check_sample_weights_pandas_series(name, Estimator):
396397
"input of type pandas.Series to class weight.")
397398

398399

400+
@ignore_warnings(category=DeprecationWarning)
401+
def check_sample_weights_list(name, Estimator):
402+
# check that estimators will accept a 'sample_weight' parameter of
403+
# type list in the 'fit' function.
404+
estimator = Estimator()
405+
if has_fit_parameter(estimator, "sample_weight"):
406+
rnd = np.random.RandomState(0)
407+
X = rnd.uniform(size=(10, 3))
408+
y = np.arange(10) % 3
409+
y = multioutput_estimator_convert_y_2d(name, y)
410+
sample_weight = [3] * 10
411+
# Test that estimators don't raise any exception
412+
estimator.fit(X, y, sample_weight=sample_weight)
413+
414+
399415
@ignore_warnings(category=(DeprecationWarning, UserWarning))
400416
def check_dtype_object(name, Estimator):
401417
# check that estimators treat dtype object as numeric if possible

0 commit comments

Comments (0)