@@ -1063,10 +1063,17 @@ cdef class BinaryTree:
1063
1063
1064
1064
def __init__ (self , data ,
1065
1065
leaf_size = 40 , metric = 'minkowski' , sample_weight = None , ** kwargs ):
1066
- self .data_arr = np .asarray (data , dtype = DTYPE , order = 'C' )
1067
- self .data = get_memview_DTYPE_2D (self .data_arr )
1066
+ # validate data
1067
+ if data .size == 0 :
1068
+ raise ValueError ("X is an empty array" )
1069
+
1070
+ if leaf_size < 1 :
1071
+ raise ValueError ("leaf_size must be greater than or equal to 1" )
1068
1072
1073
+ n_samples = data .shape [0 ]
1074
+ n_features = data .shape [1 ]
1069
1075
1076
+ self .data_arr = np .asarray (data , dtype = DTYPE , order = 'C' )
1070
1077
self .leaf_size = leaf_size
1071
1078
self .dist_metric = DistanceMetric .get_metric (metric , ** kwargs )
1072
1079
self .euclidean = (self .dist_metric .__class__ .__name__
@@ -1078,26 +1085,6 @@ cdef class BinaryTree:
1078
1085
'{BinaryTree}' .format (metric = metric ,
1079
1086
** DOC_DICT ))
1080
1087
1081
- # validate data
1082
- if self .data .size == 0 :
1083
- raise ValueError ("X is an empty array" )
1084
-
1085
- if leaf_size < 1 :
1086
- raise ValueError ("leaf_size must be greater than or equal to 1" )
1087
-
1088
- n_samples = self .data .shape [0 ]
1089
- n_features = self .data .shape [1 ]
1090
-
1091
-
1092
- if sample_weight is not None :
1093
- self .sample_weight_arr = np .asarray (sample_weight , dtype = DTYPE , order = 'C' )
1094
- self .sample_weight = get_memview_DTYPE_1D (self .sample_weight_arr )
1095
- self .sum_weight = np .sum (self .sample_weight )
1096
- else :
1097
- self .sample_weight = None
1098
- self .sum_weight = < DTYPE_t > n_samples
1099
-
1100
-
1101
1088
# determine number of levels in the tree, and from this
1102
1089
# the number of nodes in the tree. This results in leaf nodes
1103
1090
# with numbers of points between leaf_size and 2 * leaf_size
@@ -1106,15 +1093,34 @@ cdef class BinaryTree:
1106
1093
1107
1094
# allocate arrays for storage
1108
1095
self .idx_array_arr = np .arange (n_samples , dtype = ITYPE )
1109
- self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1110
-
1111
1096
self .node_data_arr = np .zeros (self .n_nodes , dtype = NodeData )
1112
- self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1097
+
1098
+ self ._update_sample_weight (n_samples , sample_weight )
1099
+ self ._update_memviews ()
1113
1100
1114
1101
# Allocate tree-specific data
1115
1102
allocate_data (self , self .n_nodes , n_features )
1116
1103
self ._recursive_build (0 , 0 , n_samples )
1117
1104
1105
+ def _update_sample_weight (self , n_samples , sample_weight ):
1106
+ if sample_weight is not None :
1107
+ self .sample_weight_arr = np .asarray (
1108
+ sample_weight , dtype = DTYPE , order = 'C' )
1109
+ self .sample_weight = get_memview_DTYPE_1D (
1110
+ self .sample_weight_arr )
1111
+ self .sum_weight = np .sum (self .sample_weight )
1112
+ else :
1113
+ self .sample_weight = None
1114
+ self .sample_weight_arr = np .empty (1 , dtype = DTYPE , order = 'C' )
1115
+ self .sum_weight = < DTYPE_t > n_samples
1116
+
1117
+ def _update_memviews (self ):
1118
+ self .data = get_memview_DTYPE_2D (self .data_arr )
1119
+ self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1120
+ self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1121
+ self .node_bounds = get_memview_DTYPE_3D (self .node_bounds_arr )
1122
+
1123
+
1118
1124
def __reduce__ (self ):
1119
1125
"""
1120
1126
reduce method used for pickling
@@ -1125,6 +1131,13 @@ cdef class BinaryTree:
1125
1131
"""
1126
1132
get state for pickling
1127
1133
"""
1134
+ if self .sample_weight is not None :
1135
+ # pass the numpy array
1136
+ sample_weight_arr = self .sample_weight_arr
1137
+ else :
1138
+ # pass None to avoid confusion with the empty place holder
1139
+ # of size 1 from __cinit__
1140
+ sample_weight_arr = None
1128
1141
return (self .data_arr ,
1129
1142
self .idx_array_arr ,
1130
1143
self .node_data_arr ,
@@ -1137,7 +1150,7 @@ cdef class BinaryTree:
1137
1150
int (self .n_splits ),
1138
1151
int (self .n_calls ),
1139
1152
self .dist_metric ,
1140
- self . sample_weight )
1153
+ sample_weight_arr )
1141
1154
1142
1155
def __setstate__ (self , state ):
1143
1156
"""
@@ -1147,12 +1160,6 @@ cdef class BinaryTree:
1147
1160
self .idx_array_arr = state [1 ]
1148
1161
self .node_data_arr = state [2 ]
1149
1162
self .node_bounds_arr = state [3 ]
1150
-
1151
- self .data = get_memview_DTYPE_2D (self .data_arr )
1152
- self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1153
- self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1154
- self .node_bounds = get_memview_DTYPE_3D (self .node_bounds_arr )
1155
-
1156
1163
self .leaf_size = state [4 ]
1157
1164
self .n_levels = state[5 ]
1158
1165
self .n_nodes = state [6 ]
@@ -1161,9 +1168,13 @@ cdef class BinaryTree:
1161
1168
self .n_splits = state [9 ]
1162
1169
self .n_calls = state [10 ]
1163
1170
self .dist_metric = state [11 ]
1171
+ sample_weight_arr = state [12 ]
1172
+
1164
1173
self .euclidean = (self .dist_metric .__class__ .__name__
1165
1174
== 'EuclideanDistance' )
1166
- self .sample_weight = state [12 ]
1175
+ n_samples = self .data_arr .shape [0 ]
1176
+ self ._update_sample_weight (n_samples , sample_weight_arr )
1177
+ self ._update_memviews ()
1167
1178
1168
1179
def get_tree_stats (self ):
1169
1180
return (self .n_trims , self .n_leaves , self .n_splits )
0 commit comments