8000 FIX Uses log2 in tree building (#30557) · scikit-learn/scikit-learn@1a2bcb5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1a2bcb5

Browse files
thomasjpfanjeremiedbb
authored andcommitted
FIX Uses log2 in tree building (#30557)
1 parent fad237e commit 1a2bcb5

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Use `log2` instead of `ln` for building trees to maintain behavior of previous
2+
versions. By `Thomas Fan`_

sklearn/tree/_partitioner.pyx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ and sparse data stored in a Compressed Sparse Column (CSC) format.
1111
# SPDX-License-Identifier: BSD-3-Clause
1212

1313
from cython cimport final
14-
from libc.math cimport isnan, log
14+
from libc.math cimport isnan, log2
1515
from libc.stdlib cimport qsort
1616
from libc.string cimport memcpy
1717

@@ -503,8 +503,8 @@ cdef class SparsePartitioner:
503503
# O(n_samples * log(n_indices)) is the running time of binary
504504
# search and O(n_indices) is the running time of index_to_samples
505505
# approach.
506-
if ((1 - self.is_samples_sorted) * n_samples * log(n_samples) +
507-
n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices):
506+
if ((1 - self.is_samples_sorted) * n_samples * log2(n_samples) +
507+
n_samples * log2(n_indices) < EXTRACT_NNZ_SWITCH * n_indices):
508508
extract_nnz_binary_search(X_indices, X_data,
509509
indptr_start, indptr_end,
510510
samples, self.start, self.end,
@@ -702,12 +702,17 @@ cdef inline void shift_missing_values_to_left_if_required(
702702
best.pos += best.n_missing
703703

704704

705+
def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n):
706+
"""Used for testing sort."""
707+
sort(&feature_values[0], &samples[0], n)
708+
709+
705710
# Sort n-element arrays pointed to by feature_values and samples, simultaneously,
706711
# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997).
707712
cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil:
708713
if n == 0:
709714
return
710-
cdef intp_t maxd = 2 * <intp_t>log(n)
715+
cdef intp_t maxd = 2 * <intp_t>log2(n)
711716
introsort(feature_values, samples, n, maxd)
712717

713718

sklearn/tree/tests/test_tree.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DENSE_SPLITTERS,
3737
SPARSE_SPLITTERS,
3838
)
39+
from sklearn.tree._partitioner import _py_sort
3940
from sklearn.tree._tree import (
4041
NODE_DTYPE,
4142
TREE_LEAF,
@@ -2814,3 +2815,25 @@ def test_build_pruned_tree_infinite_loop():
28142815
ValueError, match="Node has reached a leaf in the original tree"
28152816
):
28162817
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)
2818+
2819+
2820+
def test_sort_log2_build():
2821+
"""Non-regression test for gh-30554.
2822+
2823+
Using log2 and log in sort correctly sorts feature_values, but the tie breaking is
2824+
different which can results in placing samples in a different order.
2825+
"""
2826+
rng = np.random.default_rng(75)
2827+
some = rng.normal(loc=0.0, scale=10.0, size=10).astype(np.float32)
2828+
feature_values = np.concatenate([some] * 5)
2829+
samples = np.arange(50)
2830+
_py_sort(feature_values, samples, 50)
2831+
# fmt: off
2832+
# no black reformatting for this specific array
2833+
expected_samples = [
2834+
0, 40, 30, 20, 10, 29, 39, 19, 49, 9, 45, 15, 35, 5, 25, 11, 31,
2835+
41, 1, 21, 22, 12, 2, 42, 32, 23, 13, 43, 3, 33, 6, 36, 46, 16,
2836+
26, 4, 14, 24, 34, 44, 27, 47, 7, 37, 17, 8, 38, 48, 28, 18
2837+
]
2838+
# fmt: on
2839+
assert_array_equal(samples, expected_samples)

0 commit comments

Comments
 (0)
0