8000 BUG: Use ``Py_ssize_t`` in a few more places for strides. Add the c f… · nullnotfound/scikit-learn@0f1950c · GitHub
[go: up one dir, main page]

Skip to content

Commit 0f1950c

Browse files
committed
BUG: Use Py_ssize_t in a few more places for strides. Add the c file again.
1 parent 4db660e commit 0f1950c

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

sklearn/tree/_tree.c

Lines changed: 19 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/tree/_tree.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,9 @@ cdef class Tree:
458458
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
459459
cdef DOUBLE_t w = 1.0
460460

461-
cdef Py_ssize_t X_stride = <int> X.strides[1] / <int> X.itemsize
462-
cdef Py_ssize_t X_argsorted_stride = <int> X_argsorted.strides[1] / <int> X_argsorted.itemsize
463-
cdef Py_ssize_t y_stride = <int> y.strides[0] / <int> y.itemsize
461+
cdef Py_ssize_t X_stride = <Py_ssize_t> X.strides[1] / <int> X.itemsize
462+
cdef Py_ssize_t X_argsorted_stride = <Py_ssize_t> X_argsorted.strides[1] / <int> X_argsorted.itemsize
463+
cdef Py_ssize_t y_stride = <Py_ssize_t> y.strides[0] / <int> y.itemsize
464464

465465
cdef int n_total_samples = y.shape[0]
466466
cdef int feature
@@ -532,14 +532,14 @@ cdef class Tree:
532532
n_total_samples = n_node_samples
533533

534534
X_ptr = <DTYPE_t*> X.data
535-
X_stride = <int> X.strides[1] / <int> X.itemsize
535+
X_stride = <Py_ssize_t> X.strides[1] / <int> X.itemsize
536536
sample_mask_ptr = <BOOL_t*> sample_mask.data
537537

538538
# !! No need to update the other variables
539539
# X_argsorted_ptr = <int*> X_argsorted.data
540540
# y_ptr = <DOUBLE_t*> y.data
541-
# X_argsorted_stride = <int> X_argsorted.strides[1] / <int> X_argsorted.itemsize
542-
# y_stride = <int> y.strides[0] / <int> y.itemsize
541+
# X_argsorted_stride = <Py_ssize_t> X_argsorted.strides[1] / <int> X_argsorted.itemsize
542+
# y_stride = <Py_ssize_t> y.strides[0] / <int> y.itemsize
543543

544544
# Split
545545
X_ptr = X_ptr + feature * X_stride

0 commit comments

Comments
 (0)
0