BUG: Use Py_ssize_t to index into numpy arrays to help Python handle … · nullnotfound/scikit-learn@6a47c6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a47c6f

Browse files
committed
BUG: Use Py_ssize_t to index into numpy arrays to help Python handle big data.
Indent a few copy/pasted function declarations for consistency. Fixes scikit-learn#1466.
1 parent 46292a1 commit 6a47c6f

File tree

2 files changed

+51
-51
lines changed

2 files changed

+51
-51
lines changed

sklearn/tree/_tree.pxd

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ cdef class Criterion:
2626
cdef double weighted_n_right
2727

2828
# Methods
29-
cdef void init(self, DOUBLE_t* y, int y_stride,
29+
cdef void init(self, DOUBLE_t* y, Py_ssize_t y_stride,
3030
DOUBLE_t* sample_weight,
3131
BOOL_t* sample_mask,
3232
int n_samples,
@@ -37,7 +37,7 @@ cdef class Criterion:
3737

3838
cdef bool update(self, int a,
3939
int b,
40-
DOUBLE_t* y, int y_stride,
40+
DOUBLE_t* y, Py_ssize_t y_stride,
4141
int* X_argsorted_i,
4242
DOUBLE_t* sample_weight,
4343
BOOL_t* sample_mask)
@@ -58,7 +58,7 @@ cdef class Tree:
5858
cdef public int n_outputs
5959

6060
cdef public int max_n_classes
61-
cdef public int value_stride
61+
cdef public Py_ssize_t value_stride
6262

6363
# Parameters
6464
cdef public Criterion criterion
@@ -113,9 +113,9 @@ cdef class Tree:
113113
cdef int add_leaf(self, int parent, int is_left_child, double* value,
114114
double error, int n_samples)
115115

116-
cdef void find_split(self, DTYPE_t* X_ptr, int X_stride,
117-
int* X_argsorted_ptr, int X_argsorted_stride,
118-
DOUBLE_t* y_ptr, int y_stride,
116+
cdef void find_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
117+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
118+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
119119
DOUBLE_t* sample_weight_ptr,
120120
BOOL_t* sample_mask_ptr,
121121
int n_node_samples,
@@ -126,9 +126,9 @@ cdef class Tree:
126126
double* _best_error,
127127
double* _initial_error)
128128

129-
cdef void find_best_split(self, DTYPE_t* X_ptr, int X_stride,
130-
int* X_argsorted_ptr, int X_argsorted_stride,
131-
DOUBLE_t* y_ptr, int y_stride,
129+
cdef void find_best_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
130+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
131+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
132132
DOUBLE_t* sample_weight_ptr,
133133
BOOL_t* sample_mask_ptr,
134134
int n_node_samples,
@@ -137,9 +137,9 @@ cdef class Tree:
137137
double* _best_t, double* _best_error,
138138
double* _initial_error)
139139

140-
cdef void find_random_split(self, DTYPE_t* X_ptr, int X_stride,
141-
int* X_argsorted_ptr, int X_argsorted_stride,
142-
DOUBLE_t* y_ptr, int y_stride,
140+
cdef void find_random_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
141+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
142+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
143143
DOUBLE_t* sample_weight_ptr,
144144
BOOL_t* sample_mask_ptr,
145145
int n_node_samples,

sklearn/tree/_tree.pyx

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ cdef class Tree:
131131
# cdef public int n_outputs
132132

133133
# cdef public int max_n_classes
134-
# cdef public int value_stride
134+
# cdef public Py_ssize_t value_stride
135135

136136
# # Parameters
137137
# cdef public Criterion criterion
@@ -458,9 +458,9 @@ cdef class Tree:
458458
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
459459
cdef DOUBLE_t w = 1.0
460460

461-
cdef int X_stride = <int> X.strides[1] / <int> X.itemsize
462-
cdef int X_argsorted_stride = <int> X_argsorted.strides[1] / <int> X_argsorted.itemsize
463-
cdef int y_stride = <int> y.strides[0] / <int> y.itemsize
461+
cdef Py_ssize_t X_stride = <int> X.strides[1] / <int> X.itemsize
462+
cdef Py_ssize_t X_argsorted_stride = <int> X_argsorted.strides[1] / <int> X_argsorted.itemsize
463+
cdef Py_ssize_t y_stride = <int> y.strides[0] / <int> y.itemsize
464464

465465
cdef int n_total_samples = y.shape[0]
466466
cdef int feature
@@ -657,9 +657,9 @@ cdef class Tree:
657657

658658
return node_id
659659

660-
cdef void find_split(self, DTYPE_t* X_ptr, int X_stride,
661-
int* X_argsorted_ptr, int X_argsorted_stride,
662-
DOUBLE_t* y_ptr, int y_stride,
660+
cdef void find_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
661+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
662+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
663663
DOUBLE_t* sample_weight_ptr,
664664
BOOL_t* sample_mask_ptr,
665665
int n_node_samples,
@@ -691,9 +691,9 @@ cdef class Tree:
691691
n_total_samples, _best_i, _best_t,
692692
_best_error, _initial_error)
693693

694-
cdef void find_best_split(self, DTYPE_t* X_ptr, int X_stride,
695-
int* X_argsorted_ptr, int X_argsorted_stride,
696-
DOUBLE_t* y_ptr, int y_stride,
694+
cdef void find_best_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
695+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
696+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
697697
DOUBLE_t* sample_weight_ptr,
698698
BOOL_t* sample_mask_ptr,
699699
int n_node_samples,
@@ -822,9 +822,9 @@ cdef class Tree:
822822
_best_error[0] = best_error
823823
_initial_error[0] = initial_error
824824

825-
cdef void find_random_split(self, DTYPE_t* X_ptr, int X_stride,
826-
int* X_argsorted_ptr, int X_argsorted_stride,
827-
DOUBLE_t* y_ptr, int y_stride,
825+
cdef void find_random_split(self, DTYPE_t* X_ptr, Py_ssize_t X_stride,
826+
int* X_argsorted_ptr, Py_ssize_t X_argsorted_stride,
827+
DOUBLE_t* y_ptr, Py_ssize_t y_stride,
828828
DOUBLE_t* sample_weight_ptr,
829829
BOOL_t* sample_mask_ptr,
830830
int n_node_samples,
@@ -1034,12 +1034,12 @@ cdef class Tree:
10341034
cdef class Criterion:
10351035
"""Interface for splitting criteria (regression and classification)."""
10361036

1037-
cdef void init(self, DOUBLE_t* y, int y_stride,
1038-
DOUBLE_t* sample_weight,
1039-
BOOL_t* sample_mask,
1040-
int n_samples,
1041-
double weighted_n_samples,
1042-
int n_total_samples):
1037+
cdef void init(self, DOUBLE_t* y, Py_ssize_t y_stride,
1038+
DOUBLE_t* sample_weight,
1039+
BOOL_t* sample_mask,
1040+
int n_samples,
1041+
double weighted_n_samples,
1042+
int n_total_samples):
10431043
"""Initialise the criterion."""
10441044
pass
10451045

@@ -1048,10 +1048,10 @@ cdef class Criterion:
10481048
pass
10491049

10501050
cdef bool update(self, int a, int b,
1051-
DOUBLE_t* y, int y_stride,
1052-
int* X_argsorted_i,
1053-
DOUBLE_t* sample_weight,
1054-
BOOL_t* sample_mask):
1051+
DOUBLE_t* y, Py_ssize_t y_stride,
1052+
int* X_argsorted_i,
1053+
DOUBLE_t* sample_weight,
1054+
BOOL_t* sample_mask):
10551055
"""Update the criteria for each value in interval [a,b) (where a and b
10561056
are indices in `X_argsorted_i`)."""
10571057
pass
@@ -1083,7 +1083,7 @@ cdef class ClassificationCriterion(Criterion):
10831083
weighted_n_samples : double
10841084
The weighted number of samples.
10851085
1086-
label_count_stride : int
1086+
label_count_stride : Py_ssize_t
10871087
The stride between outputs in label_count_* arrays.
10881088
10891089
label_count_left : double*
@@ -1118,7 +1118,7 @@ cdef class ClassificationCriterion(Criterion):
11181118
"""
11191119
cdef int* n_classes
11201120

1121-
cdef int label_count_stride
1121+
cdef Py_ssize_t label_count_stride
11221122
cdef double* label_count_left
11231123
cdef double* label_count_right
11241124
cdef double* label_count_init
@@ -1139,7 +1139,7 @@ cdef class ClassificationCriterion(Criterion):
11391139
if self.n_classes == NULL:
11401140
raise MemoryError()
11411141

1142-
cdef int label_count_stride = -1
1142+
cdef Py_ssize_t label_count_stride = -1
11431143

11441144
for k from 0 <= k < n_outputs:
11451145
self.n_classes[k] = n_classes[k]
@@ -1183,7 +1183,7 @@ cdef class ClassificationCriterion(Criterion):
11831183
def __setstate__(self, d):
11841184
pass
11851185

1186-
cdef void init(self, DOUBLE_t* y, int y_stride,
1186+
cdef void init(self, DOUBLE_t* y, Py_ssize_t y_stride,
11871187
DOUBLE_t* sample_weight,
11881188
BOOL_t* sample_mask,
11891189
int n_samples,
@@ -1192,7 +1192,7 @@ cdef class ClassificationCriterion(Criterion):
11921192
"""Initialise the criterion."""
11931193
cdef int n_outputs = self.n_outputs
11941194
cdef int* n_classes = self.n_classes
1195-
cdef int label_count_stride = self.label_count_stride
1195+
cdef Py_ssize_t label_count_stride = self.label_count_stride
11961196
cdef double* label_count_init = self.label_count_init
11971197

11981198
cdef int k = 0
@@ -1223,7 +1223,7 @@ cdef class ClassificationCriterion(Criterion):
12231223
"""Reset the criterion for a new feature index."""
12241224
cdef int n_outputs = self.n_outputs
12251225
cdef int* n_classes = self.n_classes
1226-
cdef int label_count_stride = self.label_count_stride
1226+
cdef Py_ssize_t label_count_stride = self.label_count_stride
12271227
cdef double* label_count_init = self.label_count_init
12281228
cdef double* label_count_left = self.label_count_left
12291229
cdef double* label_count_right = self.label_count_right
@@ -1244,15 +1244,15 @@ cdef class ClassificationCriterion(Criterion):
12441244
label_count_right[k * label_count_stride + c] = label_count_init[k * label_count_stride + c]
12451245

12461246
cdef bool update(self, int a, int b,
1247-
DOUBLE_t* y, int y_stride,
1247+
DOUBLE_t* y, Py_ssize_t y_stride,
12481248
int* X_argsorted_i,
12491249
DOUBLE_t* sample_weight,
12501250
BOOL_t* sample_mask):
12511251
"""Update the criteria for each value in interval [a,b) (where a and b
12521252
are indices in `X_argsorted_i`)."""
12531253
cdef int n_outputs = self.n_outputs
12541254
cdef int* n_classes = self.n_classes
1255-
cdef int label_count_stride = self.label_count_stride
1255+
cdef Py_ssize_t label_count_stride = self.label_count_stride
12561256
cdef double* label_count_left = self.label_count_left
12571257
cdef double* label_count_right = self.label_count_right
12581258
cdef int n_left = self.n_left
@@ -1310,7 +1310,7 @@ cdef class ClassificationCriterion(Criterion):
13101310
before)."""
13111311
cdef int n_outputs = self.n_outputs
13121312
cdef int* n_classes = self.n_classes
1313-
cdef int label_count_stride = self.label_count_stride
1313+
cdef Py_ssize_t label_count_stride = self.label_count_stride
13141314
cdef double* label_count_init = self.label_count_init
13151315

13161316
cdef int k, c
@@ -1342,7 +1342,7 @@ cdef class Gini(ClassificationCriterion):
13421342
cdef double n_samples = self.weighted_n_samples
13431343
cdef int n_outputs = self.n_outputs
13441344
cdef int* n_classes = self.n_classes
1345-
cdef int label_count_stride = self.label_count_stride
1345+
cdef Py_ssize_t label_count_stride = self.label_count_stride
13461346
cdef double* label_count_left = self.label_count_left
13471347
cdef double* label_count_right = self.label_count_right
13481348
cdef double n_left = self.weighted_n_left
@@ -1404,7 +1404,7 @@ cdef class Entropy(ClassificationCriterion):
14041404
cdef double n_samples = self.weighted_n_samples
14051405
cdef int n_outputs = self.n_outputs
14061406
cdef int* n_classes = self.n_classes
1407-
cdef int label_count_stride = self.label_count_stride
1407+
cdef Py_ssize_t label_count_stride = self.label_count_stride
14081408
cdef double* label_count_left = self.label_count_left
14091409
cdef double* label_count_right = self.label_count_right
14101410
cdef double n_left = self.weighted_n_left
@@ -1564,7 +1564,7 @@ cdef class RegressionCriterion(Criterion):
15641564
def __setstate__(self, d):
15651565
pass
15661566

1567-
cdef void init(self, DOUBLE_t* y, int y_stride,
1567+
cdef void init(self, DOUBLE_t* y, Py_ssize_t y_stride,
15681568
DOUBLE_t* sample_weight,
15691569
BOOL_t* sample_mask,
15701570
int n_samples,
@@ -1654,10 +1654,10 @@ cdef class RegressionCriterion(Criterion):
16541654
weighted_n_samples * (mean_right[k] * mean_right[k]))
16551655

16561656
cdef bool update(self, int a, int b,
1657-
DOUBLE_t* y, int y_stride,
1658-
int* X_argsorted_i,
1659-
DOUBLE_t* sample_weight,
1660-
BOOL_t* sample_mask):
1657+
DOUBLE_t* y, Py_ssize_t y_stride,
1658+
int* X_argsorted_i,
1659+
DOUBLE_t* sample_weight,
1660+
BOOL_t* sample_mask):
16611661
"""Update the criteria for each value in interval [a,b) (where a and b
16621662
are indices in `X_argsorted_i`)."""
16631663
cdef double* mean_left = self.mean_left

0 commit comments

Comments
 (0)
0