fixed the oob_score_ issue, simplified the self.value accesses · scikit-learn/scikit-learn@d198f20 · GitHub


Commit d198f20
committed: fixed the oob_score_ issue, simplified the self.value accesses
1 parent: c61c8dc

File tree: 2 files changed (+38, −40 lines)

sklearn/ensemble/_forest.py

Lines changed: 10 additions & 8 deletions
@@ -708,8 +708,11 @@ def _compute_unbiased_feature_importance_and_oob_predictions_per_tree(
                     method=method,
                 )
             )
+            # If classification, turn the predict proba into
+            # the one-hot encoded majority class
             oob_pred[oob_indices, :, :] += y_pred
             n_oob_pred[oob_indices, :] += 1
+            # print([[oob_pred[idx, :, :], self._get_oob_predictions(tree, X[oob_indices])[i]] for i,idx in enumerate(oob_indices)])
         return (importances, oob_pred, n_oob_pred)
 
     def _compute_unbiased_feature_importance_and_oob_predictions(
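Note on the hunk above: the first added comment flags a still-pending classification detail (one-hot encoding the majority class), and the commented-out print is leftover debugging. The surrounding accumulation pattern is the heart of the OOB machinery: each tree adds its predictions for its out-of-bag rows into oob_pred and bumps a per-row counter, and the caller later divides one by the other. A minimal NumPy sketch of that pattern (all names hypothetical, not the scikit-learn implementation):

```python
import numpy as np

# Hypothetical shapes mirroring the diff: (n_samples, n_classes, n_outputs)
n_samples, n_classes, n_outputs = 6, 3, 1
oob_pred = np.zeros((n_samples, n_classes, n_outputs))
n_oob_pred = np.zeros((n_samples, n_outputs), dtype=np.int64)

rng = np.random.default_rng(0)
for _ in range(10):  # one iteration per (hypothetical) tree in the forest
    # rows this tree never saw during fitting
    oob_indices = rng.choice(n_samples, size=3, replace=False)
    y_pred = rng.dirichlet(np.ones(n_classes), size=3)[:, :, np.newaxis]
    oob_pred[oob_indices, :, :] += y_pred   # same accumulation as the diff
    n_oob_pred[oob_indices, :] += 1

# average over however many trees saw each row as out-of-bag
denom = np.where(n_oob_pred == 0, 1, n_oob_pred)[:, np.newaxis, :]
oob_decision = oob_pred / denom
```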
@@ -761,7 +764,6 @@ def _compute_unbiased_feature_importance_and_oob_predictions(
 
         if not importances.any():
             return np.zeros(self.n_features_in_, dtype=np.float64), oob_pred
-
         return importances / importances.sum(), oob_pred
 
     def _get_estimators_indices(self):
@@ -901,12 +903,12 @@ def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
         if scoring_function is None:
             scoring_function = accuracy_score
 
-        ufi_feature_importances, _ = (
+        ufi_feature_importances, self.oob_decision_function_ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
                 X, y, method="ufi"
             )
         )
-        mdi_oob_feature_importances, self.oob_decision_function_ = (
+        mdi_oob_feature_importances, _ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
                 X, y, method="mdi_oob"
             )
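This swap is presumably the oob_score_ fix for classifiers: oob_decision_function_ is now taken from the "ufi" pass instead of the "mdi_oob" one. Downstream, _set_oob_score_and_attributes scores that decision function with accuracy_score, roughly as in this sketch (hand-rolled values, not the actual code path):

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 1, 2, 1])
oob_decision_function = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],  # wrong: argmax is class 0, truth is class 1
])
oob_score = accuracy_score(y_true, oob_decision_function.argmax(axis=1))
print(oob_score)  # 0.75
```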
@@ -1244,19 +1246,19 @@ def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
         if scoring_function is None:
             scoring_function = r2_score
 
-        mdi_oob_feature_importance, self.oob_prediction_ = (
+        ufi_feature_importances, self.oob_prediction_ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="mdi_oob"
+                X, y, method="ufi"
             )
         )
-        ufi_feature_importances, _ = (
+        mdi_oob_feature_importances, _ = (
             self._compute_unbiased_feature_importance_and_oob_predictions(
-                X, y, method="ufi"
+                X, y, method="mdi_oob"
             )
         )
         if self.criterion == "squared_error":
             self._ufi_feature_importances = ufi_feature_importances
-            self._mdi_oob_feature_importances = mdi_oob_feature_importance
+            self._mdi_oob_feature_importances = mdi_oob_feature_importances
 
         if self.oob_prediction_.shape[-1] == 1:
             # drop the n_outputs axis if there is a single output
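The regressor path gets the mirror-image fix: oob_prediction_ now also comes from the "ufi" pass, and the mdi_oob_feature_importance name is pluralized to match its binding. The trailing context shows the single-output squeeze before scoring; a sketch of that step plus the r2_score call (illustrative values only):

```python
import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
oob_prediction = np.array([[1.1], [1.9], [3.2], [3.8]])  # (n_samples, n_outputs=1)
if oob_prediction.shape[-1] == 1:
    oob_prediction = oob_prediction.squeeze(axis=-1)  # drop the n_outputs axis
print(r2_score(y_true, oob_prediction))
```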

sklearn/tree/_tree.pyx

Lines changed: 28 additions & 32 deletions
@@ -1292,7 +1292,7 @@ cdef class Tree:
         cdef intp_t max_n_classes = self.max_n_classes
         cdef int k, c, node_idx, sample_idx = 0
         cdef int32_t[:, ::1] count_oob_values = np.zeros((node_count, n_outputs), dtype=np.int32)
-        cdef float64_t* value_at_node = self.value + node_idx * n_outputs * max_n_classes
+        cdef int node_value_idx = -1
 
         cdef Node* node
 
@@ -1304,7 +1304,6 @@ cdef class Tree:
             # root node
             node = self.nodes
             node_idx = 0
-            value_at_node = self.value + node_idx * n_outputs * max_n_classes
             has_oob_sample[node_idx] = 1
             for k in range(n_outputs):
                 if n_classes[k] > 1:
@@ -1315,7 +1314,8 @@ cdef class Tree:
                     count_oob_values[node_idx, k] += 1
                 else:
                     if method == "ufi":
-                        oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - value_at_node[k]) ** 2.0
+                        node_value_idx = node_idx * self.value_stride + k * max_n_classes
+                        oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
                     else:
                         oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
                     count_oob_values[node_idx, k] += 1
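This is the "simplified the self.value accesses" part: pointer arithmetic over self.value is replaced by an explicit flat index. In scikit-learn's Cython Tree, value is one contiguous buffer of shape (node_count, n_outputs, max_n_classes) with value_stride == n_outputs * max_n_classes, so the new index decomposes as in this NumPy sketch. Note also that the removed value_at_node[k] read corresponds to flat index node_idx * value_stride + k, which agrees with the new node_idx * value_stride + k * max_n_classes only when max_n_classes == 1:

```python
import numpy as np

node_count, n_outputs, max_n_classes = 5, 2, 3
value = np.arange(node_count * n_outputs * max_n_classes, dtype=np.float64)
value_3d = value.reshape(node_count, n_outputs, max_n_classes)

value_stride = n_outputs * max_n_classes  # as in sklearn's Tree
node_idx, k, c = 3, 1, 2
node_value_idx = node_idx * value_stride + k * max_n_classes + c
assert value[node_value_idx] == value_3d[node_idx, k, c]
```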
@@ -1326,7 +1326,6 @@ cdef class Tree:
                     node_idx = node.left_child
                 else:
                     node_idx = node.right_child
-                value_at_node = self.value + node_idx * n_outputs * max_n_classes
                 has_oob_sample[node_idx] = 1
                 node = &self.nodes[node_idx]
                 for k in range(n_outputs):
@@ -1338,7 +1337,8 @@ cdef class Tree:
                         count_oob_values[node_idx, k] += 1
                     else:
                         if method == "ufi":
-                            oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - value_at_node[k]) ** 2.0
+                            node_value_idx = node_idx * self.value_stride + k * max_n_classes
+                            oob_node_values[node_idx, 0, k] += (y_test[k, sample_idx] - self.value[node_value_idx]) ** 2.0
                         else:
                             oob_node_values[node_idx, 0, k] += y_test[k, sample_idx]
                         count_oob_values[node_idx, k] += 1
@@ -1354,12 +1354,13 @@ cdef class Tree:
                 for c in range(n_classes[k]):
                     oob_node_values[node_idx, c, k] /= count_oob_values[node_idx, k]
             # if leaf store the predictive proba
-            if self.nodes[node_idx].left_child == _TREE_LEAF or self.nodes[node_idx].right_child == _TREE_LEAF:
+            if self.nodes[node_idx].left_child == _TREE_LEAF and self.nodes[node_idx].right_child == _TREE_LEAF:
                 for sample_idx in range(n_samples):
                     if y_leafs[sample_idx] == node_idx:
                         for k in range(n_outputs):
                             for c in range(n_classes[k]):
-                                oob_pred[sample_idx, c, k] = oob_node_values[node_idx, c, k]
+                                node_value_idx = node_idx * self.value_stride + k * max_n_classes + c
+                                oob_pred[sample_idx, c, k] = self.value[node_value_idx]
 
     cpdef compute_unbiased_feature_importance_and_oob_predictions(self, object X_test, object y_test, criterion, method="ufi"):
         cdef intp_t n_samples = X_test.shape[0]
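Two fixes in the hunk above: the leaf test now requires both children to be _TREE_LEAF rather than either one, and leaf predictions are read from the in-bag self.value rather than the OOB node values. In a fitted scikit-learn tree the two leaf conditions coincide, since leaves have both child pointers set to -1, which the public API lets you verify:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
t = DecisionTreeClassifier(random_state=0).fit(X, y).tree_

# leaves have BOTH children set to TREE_LEAF (-1), never just one
left_leaf = t.children_left == -1
right_leaf = t.children_right == -1
assert np.array_equal(left_leaf, right_leaf)
```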
@@ -1377,12 +1378,7 @@ cdef class Tree:
         cdef Node* nodes = self.nodes
         cdef Node node = nodes[0]
         cdef int k, c, offset, node_idx = 0
-        cdef int left_idx = -1
-        cdef int right_idx = -1
-
-        cdef float64_t* value_at_node = self.value + node_idx * n_outputs * max_n_classes
-        cdef float64_t* value_at_left = self.value + left_idx * n_outputs * max_n_classes
-        cdef float64_t* value_at_right = self.value + right_idx * n_outputs * max_n_classes
+        cdef int left_idx, right_idx, node_value_idx, left_value_idx, right_value_idx = -1
 
         cdef intp_t[:, ::1] y_view = np.ascontiguousarray(y_test, dtype=np.intp)
         self._compute_oob_node_values_and_predictions(X_test, y_view, oob_pred, has_oob_sample, oob_node_values, method)
13941390
left_idx = node.left_child
13951391
right_idx = node.right_child
13961392
if has_oob_sample[left_idx] and has_oob_sample[right_idx]:
1397-
value_at_node = self.value + node_idx * n_outputs * max_n_classes
1398-
value_at_left = self.value + left_idx * n_outputs * max_n_classes
1399-
value_at_right = self.value + right_idx * n_outputs * max_n_classes
1400-
offset=0
14011393
if method == "ufi":
14021394
for k in range(n_outputs):
14031395
if n_classes[k] > 1: # Classification
14041396
for c in range(n_classes[k]):
1397+
node_value_idx = node_idx * self.value_stride + k * max_n_classes + c
1398+
left_value_idx = left_idx * self.value_stride + k * max_n_classes + c
1399+
right_value_idx = right_idx * self.value_stride + k * max_n_classes + c
14051400
if criterion == "gini":
14061401
importances[node.feature] -= (
1407-
value_at_node[offset + c] * oob_node_values[node_idx, c, k]
1402+
self.value[node_value_idx] * oob_node_values[node_idx, c, k]
14081403
* node.weighted_n_node_samples
14091404
-
1410-
value_at_left[offset + c] * oob_node_values[left_idx, c, k]
1405+
self.value[left_value_idx] * oob_node_values[left_idx, c, k]
14111406
* nodes[left_idx].weighted_n_node_samples
14121407
-
1413-
value_at_right[offset + c] * oob_node_values[right_idx, c, k]
1408+
self.value[right_value_idx] * oob_node_values[right_idx, c, k]
14141409
* nodes[right_idx].weighted_n_node_samples
14151410
)
14161411
elif criterion == "log_loss":
14171412
importances[node.feature] -= (
1418-
(value_at_node[offset + c] * log(oob_node_values[node_idx, c, k])
1419-
+ log(value_at_node[offset + c]) * oob_node_values[node_idx, c, k])
1413+
(self.value[node_value_idx] * log(oob_node_values[node_idx, c, k])
1414+
+ log(self.value[node_value_idx]) * oob_node_values[node_idx, c, k])
14201415
* node.weighted_n_node_samples
14211416
)
14221417
# If one of the children is pure for oob or inbag samples, set the cross entropy to 0
1423-
if oob_node_values[left_idx, c, k] > 0.0 and value_at_left[offset + c] > 0.0:
1418+
if oob_node_values[left_idx, c, k] > 0.0 and self.value[left_value_idx] > 0.0:
14241419
importances[node.feature] += (
1425-
(value_at_left[offset + c] * log(oob_node_values[left_idx, c, k])
1426-
+ log(value_at_left[offset + c]) * oob_node_values[left_idx, c, k])
1420+
(self.value[left_value_idx] * log(oob_node_values[left_idx, c, k])
1421+
+ log(self.value[left_value_idx]) * oob_node_values[left_idx, c, k])
14271422
* nodes[left_idx].weighted_n_node_samples
14281423
)
1429-
if oob_node_values[right_idx, c, k] > 0.0 and value_at_right[offset + c] > 0.0:
1424+
if oob_node_values[right_idx, c, k] > 0.0 and self.value[right_value_idx] > 0.0:
14301425
importances[node.feature] += (
1431-
(value_at_right[offset + c] * log(oob_node_values[right_idx, c, k])
1432-
+ log(value_at_right[offset + c]) * oob_node_values[right_idx, c, k])
1426+
(self.value[right_value_idx] * log(oob_node_values[right_idx, c, k])
1427+
+ log(self.value[right_value_idx]) * oob_node_values[right_idx, c, k])
14331428
* nodes[right_idx].weighted_n_node_samples
14341429
)
1435-
offset += n_classes[k]
14361430
else: # Regression
14371431
importances[node.feature] += (
14381432
(node.impurity + oob_node_values[node_idx, 0, k])
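For classification, the "ufi" branch pairs the in-bag class proportions stored in self.value with the OOB proportions in oob_node_values, weighted by node sizes. Stripped of the tree bookkeeping, the gini-criterion contribution of one split for one class appears to reduce to the cross-product pattern below (hypothetical helper with scalar stand-ins for the array reads, not the library's API):

```python
def ufi_gini_term(p_inbag, p_oob, weighted_n_samples):
    # cross-product of in-bag and OOB class proportions at one node
    return p_inbag * p_oob * weighted_n_samples

def split_contribution(parent, left, right):
    # each argument is a (p_inbag, p_oob, weighted_n_samples) triple;
    # the diff subtracts the parent term and adds back both children
    return -(ufi_gini_term(*parent) - ufi_gini_term(*left) - ufi_gini_term(*right))

print(split_contribution((0.5, 0.5, 10.0), (0.9, 0.8, 5.0), (0.1, 0.2, 5.0)))  # ~1.2
```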
@@ -1448,16 +1442,18 @@ cdef class Tree:
             elif method == "mdi_oob":
                 for k in range(n_outputs):
                     for c in range(n_classes[k]):
+                        node_value_idx = node_idx * self.value_stride + k * max_n_classes + c
+                        left_value_idx = left_idx * self.value_stride + k * max_n_classes + c
+                        right_value_idx = right_idx * self.value_stride + k * max_n_classes + c
                         importances[node.feature] += (
-                            (value_at_node[offset + c] - value_at_left[offset + c])
+                            (self.value[node_value_idx] - self.value[left_value_idx])
                             * (oob_node_values[node_idx, c, k] - oob_node_values[left_idx, c, k])
                             * nodes[left_idx].weighted_n_node_samples
                             +
-                            (value_at_node[offset + c] - value_at_right[offset + c])
+                            (self.value[node_value_idx] - self.value[right_value_idx])
                             * (oob_node_values[node_idx, c, k] - oob_node_values[right_idx, c, k])
                             * nodes[right_idx].weighted_n_node_samples
                         )
-                        offset += n_classes[k]
                 importances[node.feature] /= n_outputs
             else:
                 raise(ValueError(method))
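The "mdi_oob" branch is unchanged in substance: only the value reads move to flat indices and the now-unused offset bookkeeping goes away. Its per-class split contribution is a product of in-bag and OOB value differences, weighted by each child's sample weight; schematically (hypothetical names mirroring the diff, not the library's API):

```python
def mdi_oob_split_contribution(v_node, v_left, v_right,
                               o_node, o_left, o_right,
                               w_left, w_right):
    # (in-bag difference) x (OOB difference), summed over both children
    return ((v_node - v_left) * (o_node - o_left) * w_left
            + (v_node - v_right) * (o_node - o_right) * w_right)

print(mdi_oob_split_contribution(0.5, 0.9, 0.1,   # in-bag values v_*
                                 0.5, 0.8, 0.2,   # OOB values o_*
                                 5.0, 5.0))       # child sample weights
```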
