FIX fix performance regression in trees with low-cardinality features… · scikit-learn/scikit-learn@84dbab2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 84dbab2

Browse files
lesteve, glemaitre, thomasjpfan
committed
FIX fix performance regression in trees with low-cardinality features (#23410)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent eb2f71e commit 84dbab2

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

sklearn/tree/_splitter.pyx

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ from ._utils cimport log
2626
from ._utils cimport rand_int
2727
from ._utils cimport rand_uniform
2828
from ._utils cimport RAND_R_MAX
29-
from ..utils._sorting cimport simultaneous_sort
3029

3130
cdef double INFINITY = np.inf
3231

@@ -342,7 +341,7 @@ cdef class BestSplitter(BaseDenseSplitter):
342341
for i in range(start, end):
343342
Xf[i] = self.X[samples[i], current.feature]
344343

345-
simultaneous_sort(&Xf[start], &samples[start], end - start)
344+
sort(&Xf[start], &samples[start], end - start)
346345

347346
if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
348347
features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]
@@ -1161,11 +1160,11 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
11611160
current.feature = features[f_j]
11621161
self.extract_nnz(current.feature, &end_negative, &start_positive,
11631162
&is_samples_sorted)
1164-
11651163
# Sort the positive and negative parts of `Xf`
1166-
simultaneous_sort(&Xf[start], &samples[start], end_negative - start)
1164+
sort(&Xf[start], &samples[start], end_negative - start)
11671165
if start_positive < end:
1168-
simultaneous_sort(&Xf[start_positive], &samples[start_positive], end - start_positive)
1166+
sort(&Xf[start_positive], &samples[start_positive],
1167+
end - start_positive)
11691168

11701169
# Update index_to_samples to take into account the sort
11711170
for p in range(start, end_negative):

0 commit comments

Comments
 (0)
0