8000 Copy the parameters in _make_estimator to prevent segfault caused by … · scikit-learn/scikit-learn@bb91401 · GitHub
[go: up one dir, main page]

Skip to content

Commit bb91401

Browse files
committed
Copy the parameters in _make_estimator to prevent segfault caused by concurrent accesses to the same Criterion instance
1 parent 4773f3e commit bb91401

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

sklearn/ensemble/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: BSD 3 clause
55

66
from abc import ABCMeta, abstractmethod
7+
import copy
78
import numbers
89
from typing import List
910

@@ -148,7 +149,7 @@ def _make_estimator(self, append=True, random_state=None):
148149
sub-estimators.
149150
"""
150151
estimator = clone(self.base_estimator_)
151-
estimator.set_params(**{p: getattr(self, p)
152+
estimator.set_params(**{p: copy.deepcopy(getattr(self, p))
152153
for p in self.estimator_params})
153154

154155
if random_state is not None:

sklearn/ensemble/tests/test_forest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,23 @@ def test_little_tree_with_small_max_samples(ForestClass):
13771377

13781378
msg = "Tree without `max_samples` restriction should have more nodes"
13791379
assert tree1.node_count > tree2.node_count, msg
1380+
1381+
1382+
@pytest.mark.parametrize(
1383+
'name', FOREST_REGRESSORS
1384+
)
1385+
def test_mse_criterion_object_segfault_smoke_test(name):
1386+
# This test exists to verify that we can successfully run all forest
1387+
# regressors with a criterion object without triggering a seg fault
1388+
# (#12623)
1389+
from sklearn.tree._classes import CRITERIA_REG
1390+
1391+
X = np.random.random((1000, 3))
1392+
y = np.random.random((1000, 1))
1393+
1394+
n_samples, n_outputs = y.shape
1395+
mse_criterion = CRITERIA_REG['mse'](n_outputs, n_samples)
1396+
est = FOREST_REGRESSORS[name](n_estimators=2, n_jobs=-1,
1397+
criterion=mse_criterion)
1398+
1399+
est.fit(X, y)

0 commit comments

Comments
 (0)
0