-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
Option to return full decision paths when predicting with decision trees or random forest #2937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
65a9042
423a41e
facf67f
826ba78
94c2f87
068f05b
2881b60
04098ff
35ad7f5
7c0e173
6fe3bb7
1a93cf7
8f3eef6
649c607
6b6d889
c5a33c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,7 +98,12 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, | |
else: | ||
tree.fit(X, y, sample_weight=sample_weight, check_input=False) | ||
|
||
return tree | ||
return tree | ||
|
||
|
||
#def _parallel_predict_paths(trees, X): | ||
# """Private function used to compute a batch of prediction paths within a job.""" | ||
# return [tree.decision_paths(X) for tree in trees] | ||
|
||
|
||
def _parallel_helper(obj, methodname, *args, **kwargs): | ||
|
@@ -299,6 +304,46 @@ def _validate_X_predict(self, X): | |
|
||
return self.estimators_[0]._validate_X_predict(X, check_input=True) | ||
|
||
|
||
|
||
def decision_paths(self, X): | ||
"""Predict class or regression value for X and return decision paths leading to the prediction, from every tree. | ||
|
||
|
||
Parameters | ||
---------- | ||
X : array-like of shape = [n_samples, n_features] | ||
The input samples. | ||
|
||
Returns | ||
------- | ||
y : list of arrays with shape = [n_estimators, n_samples, max_depth + 1] | ||
Decision paths for each each tree and for eachprediction. | ||
Each path is an array of node ids, starting with the root node id. | ||
If a path is shorter than max_depth + 1, it is padded with -1 on the right. | ||
""" | ||
|
||
# Check data | ||
if getattr(X, "dtype", None) != DTYPE or X.ndim != 2: | ||
X = array2d(X, dtype=DTYPE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use the |
||
|
||
# Assign chunk of trees to jobs | ||
n_jobs, n_trees, starts = _partition_estimators(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only n_jobs is needed
|
||
|
||
# Parallel loop | ||
path_list = Parallel(n_jobs=n_jobs, verbose=self.verbose, | ||
backend="threading")( | ||
#delayed(_parallel_predict_paths)( | ||
# self.estimators_[starts[i]:starts[i + 1]], X) | ||
#for i in range(n_jobs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove this? |
||
|
||
delayed(_parallel_helper)(e, 'decision_paths', X) | ||
for e in self.estimators_) | ||
|
||
#unpack the nested list and return | ||
return [lst for med_lst in path_list for lst in med_lst] | ||
|
||
|
||
@property | ||
def feature_importances_(self): | ||
"""Return the feature importances (the higher, the more important the | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you remove this?