@@ -1064,10 +1064,17 @@ cdef class BinaryTree:
1064
1064
1065
1065
def __init__ (self , data ,
1066
1066
leaf_size = 40 , metric = 'minkowski' , sample_weight = None , ** kwargs ):
1067
- self .data_arr = np .asarray (data , dtype = DTYPE , order = 'C' )
1068
- self .data = get_memview_DTYPE_2D (self .data_arr )
1067
+ # validate data
1068
+ if data .size == 0 :
1069
+ raise ValueError ("X is an empty array" )
1070
+
1071
+ if leaf_size < 1 :
1072
+ raise ValueError ("leaf_size must be greater than or equal to 1" )
1069
1073
1074
+ n_samples = data .shape [0 ]
1075
+ n_features = data .shape [1 ]
1070
1076
1077
+ self .data_arr = np .asarray (data , dtype = DTYPE , order = 'C' )
1071
1078
self .leaf_size = leaf_size
1072
1079
self .dist_metric = DistanceMetric .get_metric (metric , ** kwargs )
1073
1080
self .euclidean = (self .dist_metric .__class__ .__name__
@@ -1079,26 +1086,6 @@ cdef class BinaryTree:
1079
1086
'{BinaryTree}' .format (metric = metric ,
1080
1087
** DOC_DICT ))
1081
1088
1082
- # validate data
1083
- if self .data .size == 0 :
1084
- raise ValueError ("X is an empty array" )
1085
-
1086
- if leaf_size < 1 :
1087
- raise ValueError ("leaf_size must be greater than or equal to 1" )
1088
-
1089
- n_samples = self .data .shape [0 ]
1090
- n_features = self .data .shape [1 ]
1091
-
1092
-
1093
- if sample_weight is not None :
1094
- self .sample_weight_arr = np .asarray (sample_weight , dtype = DTYPE , order = 'C' )
1095
- self .sample_weight = get_memview_DTYPE_1D (self .sample_weight_arr )
1096
- self .sum_weight = np .sum (self .sample_weight )
1097
- else :
1098
- self .sample_weight = None
1099
- self .sum_weight = < DTYPE_t > n_samples
1100
-
1101
-
1102
1089
# determine number of levels in the tree, and from this
1103
1090
# the number of nodes in the tree. This results in leaf nodes
1104
1091
# with numbers of points between leaf_size and 2 * leaf_size
@@ -1107,15 +1094,34 @@ cdef class BinaryTree:
1107
1094
1108
1095
# allocate arrays for storage
1109
1096
self .idx_array_arr = np .arange (n_samples , dtype = ITYPE )
1110
- self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1111
-
1112
1097
self .node_data_arr = np .zeros (self .n_nodes , dtype = NodeData )
1113
- self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1098
+
1099
+ self ._update_sample_weight (n_samples , sample_weight )
1100
+ self ._update_memviews ()
1114
1101
1115
1102
# Allocate tree-specific data
1116
1103
allocate_data (self , self .n_nodes , n_features )
1117
1104
self ._recursive_build (0 , 0 , n_samples )
1118
1105
1106
+ def _update_sample_weight (self , n_samples , sample_weight ):
1107
+ if sample_weight is not None :
1108
+ self .sample_weight_arr = np .asarray (
1109
+ sample_weight , dtype = DTYPE , order = 'C' )
1110
+ self .sample_weight = get_memview_DTYPE_1D (
1111
+ self .sample_weight_arr )
1112
+ self .sum_weight = np .sum (self .sample_weight )
1113
+ else :
1114
+ self .sample_weight = None
1115
+ self .sample_weight_arr = np .empty (1 , dtype = DTYPE , order = 'C' )
1116
+ self .sum_weight = < DTYPE_t > n_samples
1117
+
1118
+ def _update_memviews (self ):
1119
+ self .data = get_memview_DTYPE_2D (self .data_arr )
1120
+ self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1121
+ self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1122
+ self .node_bounds = get_memview_DTYPE_3D (self .node_bounds_arr )
1123
+
1124
+
1119
1125
def __reduce__ (self ):
1120
1126
"""
1121
1127
reduce method used for pickling
@@ -1126,6 +1132,13 @@ cdef class BinaryTree:
1126
1132
"""
1127
1133
get state for pickling
1128
1134
"""
1135
+ if self .sample_weight is not None :
1136
+ # pass the numpy array
1137
+ sample_weight_arr = self .sample_weight_arr
1138
+ else :
1139
+ # pass None to avoid confusion with the empty place holder
1140
+ # of size 1 from __cinit__
1141
+ sample_weight_arr = None
1129
1142
return (self .data_arr ,
1130
1143
self .idx_array_arr ,
1131
1144
self .node_data_arr ,
@@ -1138,7 +1151,7 @@ cdef class BinaryTree:
1138
1151
int (self .n_splits ),
1139
1152
int (self .n_calls ),
1140
1153
self .dist_metric ,
1141
- self . sample_weight )
1154
+ sample_weight_arr )
1142
1155
1143
1156
def __setstate__ (self , state ):
1144
1157
"""
@@ -1148,12 +1161,6 @@ cdef class BinaryTree:
1148
1161
self .idx_array_arr = state [1 ]
1149
1162
self .node_data_arr = state [2 ]
1150
1163
self .node_bounds_arr = state [3 ]
1151
-
1152
- self .data = get_memview_DTYPE_2D (self .data_arr )
1153
- self .idx_array = get_memview_ITYPE_1D (self .idx_array_arr )
1154
- self .node_data = get_memview_NodeData_1D (self .node_data_arr )
1155
- self .node_bounds = get_memview_DTYPE_3D (self .node_bounds_arr )
1156
-
1157
1164
self .leaf_size = state [4 ]
1158
1165
self .n_levels = state [5 ]
1159
1166
self .n_nodes = state [6 ]
@@ -1162,9 +1169,13 @@ cdef class BinaryTree:
1162
1169
self .n_splits = state [9 ]
1163
1170
self .n_calls = state [10 ]
1164
1171
self .dist_metric = state [11 ]
1172
+ sample_weight_arr = state [12 ]
1173
+
1165
1174
self .euclidean = (self .dist_metric .__class__ .__name__
1166
1175
== 'EuclideanDistance' )
1167
- self .sample_weight = state [12 ]
1176
+ n_samples = self .data_arr .shape [0 ]
1177
+ self ._update_sample_weight (n_samples , sample_weight_arr )
1178
+ self ._update_memviews ()
1168
1179
1169
1180
def get_tree_stats (self ):
1170
1181
return (self .n_trims , self .n_leaves , self .n_splits )
0 commit comments