partial_dependence should respect sample weights #24872
@mayer79, have you already started working on this issue? I would love to solve it if you didn't. :)
@vitaliset Not yet started! I would be super happy if you could dig into this. I think there are two ways to calculate PDPs. For the model-agnostic logic, we would probably need to replace
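For the model-agnostic ("brute") logic, the fix likely amounts to swapping the plain mean of the predictions for a weighted mean. A minimal sketch with `np.average` (the predictions and weights here are made up for illustration):

```python
import numpy as np

# hypothetical predictions at one grid point, plus per-sample weights
predictions = np.array([1.0, 3.0, 5.0])
sample_weight = np.array([1.0, 1.0, 2.0])

unweighted = predictions.mean()                         # plain arithmetic mean
weighted = np.average(predictions, weights=sample_weight)  # weighted mean

print(unweighted, weighted)
# 3.0 3.5
```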
Thanks for giving me the direction I should follow, @mayer79! ;) I took a look at it during the last few days, and here is what I found:
```python
grid, values = _grid_from_X(
    _safe_indexing(X, features_indices, axis=1),
    percentiles,
    is_categorical,
    grid_resolution,
)
```
Note that `grid` is just the grid of values we will iterate over for the PDP calculations:
```python
import numpy as np
from sklearn import __version__ as v
print("numpy:", np.__version__, ". sklearn:", v)
>>> numpy: 1.23.3 . sklearn: 1.1.3

from sklearn.inspection._partial_dependence import _grid_from_X
from sklearn.utils import _safe_indexing
from sklearn.datasets import load_diabetes

X, _ = load_diabetes(return_X_y=True)
grid, values = _grid_from_X(
    X=_safe_indexing(X, [2, 8], axis=1), percentiles=(0, 1), grid_resolution=100
)
print("original shape of X:", X.shape, "shape of grid:", grid.shape)
>>> original shape of X: (442, 10) shape of grid: (10000, 2)

print(len(values), values[0].shape)
>>> 2 (100,)

from itertools import product
print((grid == np.array(list(product(values[0], values[1])))).all())
>>> True
```
The `grid` variable is not `X` with repeated rows (one for each value of the grid), as we would expect for `method='brute'`. Inside the `_partial_dependence_brute` function we actually do this later:
scikit-learn/sklearn/inspection/_partial_dependence.py
Lines 160 to 163 in c1cfc4d
```python
X_eval = X.copy()
for new_values in grid:
    for i, variable in enumerate(features):
        _safe_assign(X_eval, new_values[i], column_indexer=variable)
```
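Putting the pieces together, the brute method is essentially a loop over the grid that overwrites the target columns of a copy of `X` and averages the predictions. This is a simplified, hypothetical re-implementation (not the actual scikit-learn code); the final `.mean()` is where a `sample_weight`-aware average would plug in:

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
est = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

features = [2]  # arbitrary feature choice for illustration
grid = np.linspace(X[:, 2].min(), X[:, 2].max(), num=5).reshape(-1, 1)

averaged_predictions = []
for new_values in grid:
    X_eval = X.copy()
    for i, variable in enumerate(features):
        X_eval[:, variable] = new_values[i]  # every row gets this grid value
    # np.average(..., weights=sample_weight) would replace this plain mean
    averaged_predictions.append(est.predict(X_eval).mean())

print(len(averaged_predictions))
# 5
```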
This `grid` variable is what is passed on to the PDP calculations, not `X`:
scikit-learn/sklearn/inspection/_partial_dependence.py
Lines 119 to 120 in c1cfc4d
```python
def _partial_dependence_recursion(est, grid, features):
    averaged_predictions = est._compute_partial_dependence_recursion(grid, features)
```
When looking for the average prediction at a specific grid value, it does a single pass over the tree and, whenever a split is on a feature other than the one we are plotting, it weights each branch by the proportion of training samples that pass through it:
scikit-learn/sklearn/tree/_tree.pyx
Lines 1225 to 1227 in 9268eea
```python
left_sample_frac = (
    self.nodes[current_node.left_child].weighted_n_node_samples /
    current_node.weighted_n_node_samples)
```
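As a sanity check, that weighting logic can be mimicked in plain Python for a single fitted `DecisionTreeRegressor`. This is a simplified, hypothetical re-implementation (the real code is Cython and evaluates many grid points at once): at a split on the target feature we follow one side; at any other split we take the convex combination of the two children weighted by `weighted_n_node_samples`:

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
est = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
tree = est.tree_

def recurse(node, target_feature, grid_value):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:  # leaf node
        return tree.value[node][0, 0]
    if tree.feature[node] == target_feature:
        # split on the plotted feature: follow the side the grid value falls on
        child = left if grid_value <= tree.threshold[node] else right
        return recurse(child, target_feature, grid_value)
    # split on another feature: weight children by training-sample mass
    left_frac = tree.weighted_n_node_samples[left] / tree.weighted_n_node_samples[node]
    return (left_frac * recurse(left, target_feature, grid_value)
            + (1 - left_frac) * recurse(right, target_feature, grid_value))

val = recurse(0, 2, X[:, 2].mean())
print(val)
```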
Note that `weighted_n_node_samples` is an attribute of the tree.

`method='recursion'` uses the `sample_weight` from the training data... but not always
Nonetheless, I found something "odd". There are two slightly different implementations of the `compute_partial_dependence` function in scikit-learn: one for the models based on the CART implementation and one for the HistGradientBoosting estimators. The algorithms based on the CART implementation use the `sample_weight` passed to the `.fit` method through the `weighted_n_node_samples` attribute (code above), while the HistGradientBoosting estimators don't: they just count the number of samples in each leaf (even if the model was fitted with `sample_weight`).
scikit-learn/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
Lines 190 to 192 in ff6f880
```cython
left_sample_frac = (
    <Y_DTYPE_C> nodes[current_node.left].count /
    current_node.count)
```
You can see that this looks right from this small snippet I ran:
```python
import numpy as np
from sklearn import __version__ as v
print("numpy:", np.__version__, ". sklearn:", v)
>>> numpy: 1.23.3 . sklearn: 1.1.3

from sklearn.datasets import load_diabetes
X, y = load_diabetes(return_X_y=True)
sample_weights = np.random.RandomState(42).uniform(0, 1, size=X.shape[0])

from sklearn.tree import DecisionTreeRegressor
dtr_nsw = DecisionTreeRegressor(max_depth=1, random_state=42).fit(X, y)
dtr_sw = DecisionTreeRegressor(max_depth=1, random_state=42).fit(X, y, sample_weight=sample_weights)
print(dtr_nsw.tree_.weighted_n_node_samples, dtr_sw.tree_.weighted_n_node_samples)
>>> [442. 218. 224.] [218.28015122 108.90401865 109.37613257]

from sklearn.ensemble import RandomForestRegressor
rfr_nsw = RandomForestRegressor(max_depth=1, random_state=42).fit(X, y)
rfr_sw = RandomForestRegressor(max_depth=1, random_state=42).fit(X, y, sample_weight=sample_weights)
print(rfr_nsw.estimators_[0].tree_.weighted_n_node_samples, rfr_sw.estimators_[0].tree_.weighted_n_node_samples)
>>> [442. 288. 154.] [226.79228463 148.44294465 78.34933998]

from sklearn.ensemble import HistGradientBoostingRegressor
hgbr_nsw = HistGradientBoostingRegressor(max_depth=2, random_state=42).fit(X, y)
hgbr_sw = HistGradientBoostingRegressor(max_depth=2, random_state=42).fit(X, y, sample_weight=sample_weights)

import pandas as pd
pd.DataFrame(hgbr_nsw._predictors[0][0].nodes)
pd.DataFrame(hgbr_sw._predictors[0][0].nodes)
```
The `weighted_n_node_samples` attribute takes the weighting into account (which is why it is a `float`), while the `.count` from the predictors only looks at the number of samples at each node (which is why it is an `int`).
Takeaways

- `method='brute'` should be straightforward, and I'll create a PR for it soon. I'm still determining the tests, but I can add extra ones during review time.
- Because `method='recursion'` explicitly doesn't use `X` for the PDP calculations, I don't think it makes sense to try to implement `sample_weight` for it, and I'll raise an error in that case.
- Nonetheless, we can discuss the mismatch we see between different models when calculating the PDP with training `sample_weight`, and make it uniform across algorithms if we think this is relevant. It doesn't look like a big priority, but knowing we have this problem is nice. I don't think it should be that hard to keep track of the weighted samples in the `nodes` attribute.
Fantastic research! Additional possible tests for "brute": PDP unweighted is the same as PDP with all weights 1.0; same for all weights 2.0. My feeling is: the "recursion" approach for trees should respect the sample weights of the training data when tracking split weights. I cannot explain why the two tree methods are different. Should we open an issue for clarification?
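Those invariance checks can be sketched with plain NumPy, assuming the weighted aggregation ends up using `np.average` (the predictions here are random stand-ins for predictions at one grid point):

```python
import numpy as np

rng = np.random.RandomState(0)
preds = rng.uniform(size=50)  # stand-in predictions at one grid point

unweighted = preds.mean()
all_ones = np.average(preds, weights=np.ones_like(preds))     # all weights 1.0
all_twos = np.average(preds, weights=np.full_like(preds, 2.0))  # all weights 2.0

print(np.allclose(unweighted, all_ones), np.allclose(unweighted, all_twos))
# True True
```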
That's correct - this is something that is explicitly not supported in the HGBDT trees yet:

scikit-learn/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
Lines 1142 to 1148 in 205f3b7

(See #14696 (comment) for the historical decision / context.) I thought there was an open issue for that, but it looks like there isn't. Feel free to open one!
Describe the workflow you want to enable
Currently, the `inspection.partial_dependence` functions calculate arithmetic averages over predictions. For models fitted with sample weights, this is somewhere between suboptimal and wrong.
Describe your proposed solution
Add a new argument `sample_weight=None`. If a vector of the right length is passed, replace the arithmetic average of predictions by a weighted average.
Note that this does not affect the calculation of ICE curves, just the aggregate.
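In other words, individual ICE curves stay exactly as they are; only the row-wise aggregation into the PDP changes. A sketch with a made-up ICE matrix (rows are samples, columns are grid points):

```python
import numpy as np

# hypothetical ICE matrix: one row per sample, one column per grid point
ice = np.array([[1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0]])
sample_weight = np.array([1.0, 1.0, 2.0])

pdp_unweighted = ice.mean(axis=0)                           # current behavior
pdp_weighted = np.average(ice, axis=0, weights=sample_weight)  # proposed

print(pdp_unweighted, pdp_weighted)
# [3. 4.] [3.5 4.5]
```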
Describe alternatives you've considered, if relevant
No response
Additional context
No response