8000 FIX fix performance regression in trees with low-cardinality features… · glemaitre/scikit-learn@007ea89 · GitHub
[go: up one dir, main page]

Skip to content

Commit 007ea89

Browse files
lesteveglemaitrethomasjpfan
committed
FIX fix performance regression in trees with low-cardinality features (scikit-learn#23410)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent a656d8a commit 007ea89

File tree

2 files changed

+135
-11
lines changed

2 files changed

+135
-11
lines changed

doc/whats_new/v1.1.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,19 @@ Changelog
3939
classifier that always predicts the positive class: recall=100% and
4040
precision=class balance.
4141
:pr:`23214` by :user:`Stéphane Collot <stephanecollot>` and :user:`Max Baak <mbaak>`.
42-
42+
43+
:mod:`sklearn.tree`
44+
...................
45+
46+
- |Fix| Fixes performance regression with low cardinality features for
47+
:class:`tree.DecisionTreeClassifier`,
48+
:class:`tree.DecisionTreeRegressor`,
49+
:class:`ensemble.RandomForestClassifier`,
50+
:class:`ensemble.RandomForestRegressor`,
51+
:class:`ensemble.GradientBoostingClassifier`, and
52+
:class:`ensemble.GradientBoostingRegressor`.
53+
:pr:`23410` by :user:`Loïc Estève <lesteve>`
54+
4355
:mod:`sklearn.utils`
4456
....................
4557

sklearn/tree/_splitter.pyx

Lines changed: 122 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ from ._utils cimport rand_int
2929
from ._utils cimport rand_uniform
3030
from ._utils cimport RAND_R_MAX
3131
from ._utils cimport safe_realloc
32-
from ..utils._sorting cimport simultaneous_sort
3332

3433
cdef double INFINITY = np.inf
3534

@@ -359,7 +358,7 @@ cdef class BestSplitter(BaseDenseSplitter):
359358
for i in range(start, end):
360359
Xf[i] = self.X[samples[i], current.feature]
361360

362-
simultaneous_sort(Xf + start, samples + start, end - start)
361+
sort(Xf + start, samples + start, end - start)
363362

364363
if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
365364
features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]
@@ -452,6 +451,120 @@ cdef class BestSplitter(BaseDenseSplitter):
452451
return 0
453452

454453

454+
# Sort n-element arrays pointed to by Xf and samples, simultaneously,
455+
# by the values in Xf. Algorithm: Introsort (Musser, SP&E, 1997).
456+
cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil:
457+
if n == 0:
458+
return
459+
cdef int maxd = 2 * <int>log(n)
460+
introsort(Xf, samples, n, maxd)
461+
462+
463+
cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples,
464+
SIZE_t i, SIZE_t j) nogil:
465+
# Helper for sort
466+
Xf[i], Xf[j] = Xf[j], Xf[i]
467+
samples[i], samples[j] = samples[j], samples[i]
468+
469+
470+
cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil:
471+
# Median of three pivot selection, after Bentley and McIlroy (1993).
472+
# Engineering a sort function. SP&E. Requires 8/3 comparisons on average.
473+
cdef DTYPE_t a = Xf[0], b = Xf[n / 2], c = Xf[n - 1]
474+
if a < b:
475+
if b < c:
476+
return b
477+
elif a < c:
478+
return c
479+
else:
480+
return a
481+
elif b < c:
482+
if a < c:
483+
return a
484+
else:
485+
return c
486+
else:
487+
return b
488+
489+
490+
# Introsort with median of 3 pivot selection and 3-way partition function
491+
# (robust to repeated elements, e.g. lots of zero features).
492+
cdef void introsort(DTYPE_t* Xf, SIZE_t *samples,
493+
SIZE_t n, int maxd) nogil:
494+
cdef DTYPE_t pivot
495+
cdef SIZE_t i, l, r
496+
497+
while n > 1:
498+
if maxd <= 0: # max depth limit exceeded ("gone quadratic")
499+
heapsort(Xf, samples, n)
500+
return
501+
maxd -= 1
502+
503+
pivot = median3(Xf, n)
504+
505+
# Three-way partition.
506+
i = l = 0
507+
r = n
508+
while i < r:
509+
if Xf[i] < pivot:
510+
swap(Xf, samples, i, l)
511+
i += 1
512+
l += 1
513+
elif Xf[i] > pivot:
514+
r -= 1
515+
swap(Xf, samples, i, r)
516+
else:
517+
i += 1
518+
519+
introsort(Xf, samples, l, maxd)
520+
Xf += r
521+
samples += r
522+
n -= r
523+
524+
525+
cdef inline void sift_down(DTYPE_t* Xf, SIZE_t* samples,
526+
SIZE_t start, SIZE_t end) nogil:
527+
# Restore heap order in Xf[start:end] by moving the max element to start.
528+
cdef SIZE_t child, maxind, root
529+
530+
root = start
531+
while True:
532+
F438 child = root * 2 + 1
533+
534+
# find max of root, left child, right child
535+
maxind = root
536+
if child < end and Xf[maxind] < Xf[child]:
537+
maxind = child
538+
if child + 1 < end and Xf[maxind] < Xf[child + 1]:
539+
maxind = child + 1
540+
541+
if maxind == root:
542+
break
543+
else:
544+
swap(Xf, samples, root, maxind)
545+
root = maxind
546+
547+
548+
cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil:
549+
cdef SIZE_t start, end
550+
551+
# heapify
552+
start = (n - 2) / 2
553+
end = n
554+
while True:
555+
sift_down(Xf, samples, start, end)
556+
if start == 0:
557+
break
558+
start -= 1
559+
560+
# sort by shrinking the heap, putting the max element immediately after it
561+
end = n - 1
562+
while end > 0:
563+
swap(Xf, samples, 0, end)
564+
sift_down(Xf, samples, 0, end)
565+
end = end - 1
566+
567+
455568
cdef class RandomSplitter(BaseDenseSplitter):
456569
"""Splitter for finding the best random split."""
457570
def __reduce__(self):
@@ -1077,14 +1190,13 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
10771190
f_j += n_found_constants
10781191
# f_j in the interval [n_total_constants, f_i[
10791192

1080-
current.feature = features[f_j]
1081-
self.extract_nnz(current.feature,
1082-
&end_negative, &start_positive,
1083-
&is_samples_sorted)
1084-
1085-
# Sort the positive and negative parts of `Xf`
1086-
simultaneous_sort(Xf + start, samples + start, end_negative - start)
1087-
simultaneous_sort(Xf + start_positive, samples + start_positive,
1193+
current.feature = features[f_j]
1194+
self.extract_nnz(current.feature, &end_negative, &start_positive,
1195+
&is_samples_sorted)
1196+
# Sort the positive and negative parts of `Xf`
1197+
sort(Xf + start, samples + start, end_negative - start)
1198+
if start_positive < end:
1199+
sort(Xf + start_positive, samples + start_positive,
10881200
end - start_positive)
10891201

10901202
# Update index_to_samples to take into account the sort

0 commit comments

Comments
 (0)
0