8000 [MRG + 1] Issue#8062: JoblibException thrown when passing "fit_params… · Sundrique/scikit-learn@a7e6bc0 · GitHub
[go: up one dir, main page]

Skip to content

Commit a7e6bc0

Browse files
xorSundrique
authored andcommitted
[MRG + 1] Issue#8062: JoblibException thrown when passing "fit_params={'sample_… (scikit-learn#8068)
* Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
1 parent 15c39fe commit a7e6bc0

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sklearn/ensemble/forest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ def fit(self, X, y, sample_weight=None):
245245
# Validate or convert input data
246246
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
247247
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
248+
if sample_weight is not None:
249+
sample_weight = check_array(sample_weight, ensure_2d=False)
248250
if issparse(X):
249251
# Pre-sort indices to avoid that each individual tree of the
250252
# ensemble sorts the indices.

sklearn/ensemble/tests/test_forest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,12 @@ def check_class_weights(name):
947947
clf2.fit(iris.data, iris.target, sample_weight)
948948
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
949949

950+
# Using a Python 2.x list as the sample_weight parameter used to raise
951+
# an exception. This test makes sure such code will now run correctly.
952+
clf = ForestClassifier()
953+
sample_weight = [1.] * len(iris.data)
954+
clf.fit(iris.data, iris.target, sample_weight=sample_weight)
955+
950956

951957
def test_class_weights():
952958
for name in FOREST_CLASSIFIERS:

0 commit comments

Comments
 (0)
0