Commit 3ff8dc2

Author: Samuel Brice
Deep copy the Criterion instance within BaseDecisionTree.fit to prevent segfault caused by concurrent accesses.
1 parent 15c2c72 commit 3ff8dc2

3 files changed: 35 additions, 0 deletions
doc/whats_new/v0.24.rst

Lines changed: 10 additions & 0 deletions

@@ -54,6 +54,16 @@ Changelog
 - |Fix| Better contains the CSS provided by :func:`utils.estimator_html_repr`
   by giving CSS ids to the html representation. :pr:`19417` by `Thomas Fan`_.
 
+:mod:`sklearn.tree`
+.......................
+
+- |Fix| Fix a bug in `fit` of :class:`tree.BaseDecisionTree` that caused
+  segmentation faults under certain conditions. `fit` now deep copies the
+  `Criterion` object to prevent shared concurrent accesses.
+  :pr:`19580` by :user:`Samuel Brice <samdbrice>` and
+  :user:`Alex Adamson <aadamson>` and
+  :user:`Wil Yegelwel <wyegelwel>`.
+
 .. _changes_0_24_1:
 
 Version 0.24.1

sklearn/ensemble/tests/test_forest.py

Lines changed: 20 additions & 0 deletions

@@ -1494,3 +1494,23 @@ def test_n_features_deprecation(Estimator):
 
     with pytest.warns(FutureWarning, match="n_features_ was deprecated"):
         est.n_features_
+
+
+@pytest.mark.parametrize('Forest', FOREST_REGRESSORS)
+def test_mse_criterion_object_segfault_smoke_test(Forest):
+    # This is a smoke test to ensure that passing a mutable criterion
+    # does not cause a segfault when fitting with concurrent threads.
+    # Non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/12623
+    from sklearn.tree._classes import CRITERIA_REG
+
+    X = np.random.random((1000, 3))
+    y = np.random.random((1000, 1))
+
+    n_samples, n_outputs = y.shape
+    mse_criterion = CRITERIA_REG['mse'](n_outputs, n_samples)
+    est = FOREST_REGRESSORS[Forest](
+        n_estimators=2, n_jobs=2, criterion=mse_criterion
+    )
+
+    est.fit(X, y)

sklearn/tree/_classes.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 import numbers
 import warnings
+import copy
 from abc import ABCMeta
 from abc import abstractmethod
 from math import ceil
@@ -349,6 +350,10 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             else:
                 criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
                                                          n_samples)
+        else:
+            # Make a deepcopy in case the criterion has mutable attributes that
+            # might be shared and modified concurrently during parallel fitting
+            criterion = copy.deepcopy(criterion)
 
         SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
 
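
The copy-on-fit idea in the `fit` change above can be sketched with the standard library alone. This is a minimal illustration, not scikit-learn code: `ToyCriterion` and `ToyTree` are hypothetical stand-ins, and the only assumption carried over from the commit is that a user-supplied criterion object holds mutable state that several concurrently fitted estimators would otherwise share.

```python
import copy
from concurrent.futures import ThreadPoolExecutor


class ToyCriterion:
    """Hypothetical stand-in for a criterion object with mutable state."""

    def __init__(self):
        self.running_sum = 0.0  # scratch state that fitting mutates

    def update(self, value):
        self.running_sum += value


class ToyTree:
    """Hypothetical estimator; it mirrors only the copy-on-fit pattern."""

    def __init__(self, criterion):
        self.criterion = criterion

    def fit(self, values):
        # Work on a private deep copy so that concurrent fits never
        # touch the caller's (possibly shared) criterion instance.
        criterion = copy.deepcopy(self.criterion)
        for value in values:
            criterion.update(value)
        self.fitted_criterion_ = criterion
        return self


# One criterion instance handed to several estimators, as a forest would do.
shared = ToyCriterion()
trees = [ToyTree(shared) for _ in range(4)]
data = [1.0, 2.0, 3.0]

# Fit all estimators from concurrent threads.
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(lambda tree: tree.fit(data), trees))

sums = [tree.fitted_criterion_.running_sum for tree in trees]
```

Because each `fit` call mutates only its own copy, the shared instance is left untouched and every estimator ends up with an independent result, regardless of how the threads interleave.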
