10000 FIX Deep copy criterion in trees to fix concurrency bug (#19580) · scikit-learn/scikit-learn@142115d · GitHub
[go: up one dir, main page]

Skip to content

Commit 142115d

Browse files
samdbriceSamuel Brice
authored andcommitted
FIX Deep copy criterion in trees to fix concurrency bug (#19580)
Co-authored-by: Samuel Brice <samuel.brice@twosigma.com>
1 parent 83a1031 commit 142115d

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

doc/whats_new/v0.24.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ Changelog
4848
:class:`~sklearn.semi_supervised.LabelPropagation`.
4949
:pr:`19271` by :user:`Zhaowei Wang <ThuWangzw>`.
5050

51+
:mod:`sklearn.tree`
52+
.......................
53+
54+
- |Fix| Fix a bug in `fit` of :class:`tree.BaseDecisionTree` that caused
55+
segmentation faults under certain conditions. `fit` now deep copies the
56+
`Criterion` object to prevent shared concurrent accesses.
57+
:pr:`19580` by :user:`Samuel Brice <samdbrice>` and
58+
:user:`Alex Adamson <aadamson>` and
59+
:user:`Wil Yegelwel <wyegelwel>`.
60+
5161
:mod:`sklearn.utils`
5262
....................
5363

sklearn/ensemble/tests/test_forest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,21 @@ def test_little_tree_with_small_max_samples(ForestClass):
13771377

13781378
msg = "Tree without `max_samples` restriction should have more nodes"
13791379
assert tree1.node_count > tree2.node_count, msg
1380+
1381+
1382+
@pytest.mark.parametrize('Forest', FOREST_REGRESSORS)
1383+
def test_mse_criterion_object_segfault_smoke_test(Forest):
1384+
# This is a smoke test to ensure that passing a mutable criterion
1385+
# does not cause a segfault when fitting with concurrent threads.
1386+
# Non-regression test for:
1387+
# https://github.com/scikit-learn/scikit-learn/issues/12623
1388+
from sklearn.tree._criterion import MSE
1389+
1390+
y = y_reg.reshape(-1, 1)
1391+
n_samples, n_outputs = y.shape
1392+
mse_criterion = MSE(n_outputs, n_samples)
1393+
est = FOREST_REGRESSORS[Forest](
1394+
n_estimators=2, n_jobs=2, criterion=mse_criterion
1395+
)
1396+
1397+
est.fit(X_reg, y)

sklearn/tree/_classes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numbers
1818
import warnings
19+
import copy
1920
from abc import ABCMeta
2021
from abc import abstractmethod
2122
from math import ceil
@@ -349,6 +350,10 @@ def fit(self, X, y, sample_weight=None, check_input=True,
349350
else:
350351
criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
351352
n_samples)
353+
else:
354+
# Make a deepcopy in case the criterion has mutable attributes that
355+
# might be shared and modified concurrently during parallel fitting
356+
criterion = copy.deepcopy(criterion)
352357

353358
SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
354359

0 commit comments

Comments
 (0)
0