ENH partial_dependence plot for HistGradientBoosting estimator fitted with `sample_weight` · Issue #25210 · scikit-learn/scikit-learn · GitHub

ENH partial_dependence plot for HistGradientBoosting estimator fitted with sample_weight #25210


Open
vitaliset opened this issue Dec 19, 2022 · 18 comments

Comments

@vitaliset
Contributor

Describe the workflow you want to enable

As the partial dependence of a model at a point is defined as an expectation, it should respect sample_weight if someone wishes to use it (for instance, when you know your X does not follow the distribution you are interested in).

#25209 tries to solve this for method='brute' when you have new X. For the other tree-based models trained with sample_weight, method='recursion' keeps track of the training sample_weight and computes the partial dependence taking it into account.
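To make the intuition concrete, here is a minimal NumPy sketch (with a toy predict function, not a scikit-learn estimator) of what a weighted 'brute' partial dependence looks like: the plain mean over the dataset is replaced by a weighted average.

```python
import numpy as np

# Toy "model": predictions depend on feature 0 and feature 1.
def predict(X):
    return 2.0 * X[:, 0] + X[:, 1] ** 2

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
sample_weight = rng.exponential(size=100)

# Brute-force partial dependence of feature 0 at a grid of values:
# replace the feature with each grid value and average the predictions,
# weighting by sample_weight instead of taking a plain mean.
grid = np.linspace(-2, 2, 5)
pd_values = []
for v in grid:
    X_mod = X.copy()
    X_mod[:, 0] = v
    pd_values.append(np.average(predict(X_mod), weights=sample_weight))

print(pd_values)
```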

But, as discussed during the implementation of sample_weight on the HistGradientBoosting estimators (#14696 (comment)), these models store an attribute _fitted_with_sw, and when partial_dependence with method='recursion' is requested, they raise an error:

if getattr(self, "_fitted_with_sw", False):
    raise NotImplementedError(
        "{} does not support partial dependence "
        "plots with the 'recursion' method when "
        "sample weights were given during fit "
        "time.".format(self.__class__.__name__)
    )

Describe your proposed solution

As discussed in #24872 (comment), the big difference between other tree-based algorithms and HistGradientBoosting is that HistGradientBoosting does not save the weighted_n_node_samples when building the tree.
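For contrast, the classic tree module already exposes this quantity: after fitting with sample_weight, each node of a DecisionTree carries the total weight of the samples it contains, which is exactly what method='recursion' relies on. A small sketch:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 1))
y = X[:, 0] + rng.normal(scale=0.1, size=50)
sample_weight = rng.exponential(size=50)

tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(
    X, y, sample_weight=sample_weight
)

# One entry per node; the root's value is the sum of all sample weights.
print(tree.tree_.weighted_n_node_samples)
```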

Describe alternatives you've considered, if relevant

No response

Additional context

No response

@vitaliset vitaliset added Needs Triage Issue requires triage New Feature labels Dec 19, 2022
@vitaliset
Contributor Author

It is not clear to me why it would be hard to keep track of weighted_n_node_samples, as intuitively the tree should be using it during loss/criterion calculations at each node. But on the original PR for implementing sample_weight on HistGradientBoosting, @adrinjalali points out that:

adding sample weight support for pdp requires adding a weighted_count to node_struct. It'll be a non-trivial change

@adrinjalali, would you mind explaining why you find it non-trivial to do? I would love to try to solve this issue if I can get a little push. :D

@adrinjalali
Member

IIRC it's because it'd increase the memory footprint of the model, as well as slowing the model down. I'm not opposed to adding them, but adding them in #14696 would have meant requiring more benchmarks, and more tests, which wasn't ideal since that PR took quite a while and quite an effort to finish in the first place.

It's certainly doable as a separate PR, and happy if you're up for pushing a PR for it, but be warned that it might be a non-trivial PR 😁

@thomasjpfan thomasjpfan added module:ensemble and removed Needs Triage Issue requires triage labels Dec 19, 2022
@Andrew-Wang-IB45
Contributor

@vitaliset, have you started working on this? I would like to take this on as well.

@vitaliset
Contributor Author
vitaliset commented Dec 27, 2022

Feel free to jump in and take it, @Andrew-Wang-IB45. I will not be able to start it in the next few weeks and would be super happy if you could come up with something in the meantime!

Please take a look at the issues I mentioned in the description of this one (especially 24872) for discussions related to the solution.

From my point of view, the solution involves two parts (maybe we should have two sequential PRs for easier reviews):

1 - Update the training function of the base estimator of HistGradientBoosting, so it keeps track of the weighted_n_node_samples (just like the other tree-based models do).

  • For this, we should benchmark performance and memory footprint, as HistGradientBoosting is meant to be an efficient estimator (it will probably still be, but we need to check that we are not messing it up).

2 - Update the compute_partial_dependence function of the base estimators of HistGradientBoosting, so it looks at the new weighted_n_node_samples instead of the count.

What do you think? Does this approach make sense to you?

Please reach out and ask for help if you need it! :)

@Andrew-Wang-IB45
Contributor

Sounds good, @vitaliset. I will review the issues and pull requests that you have brought up. Your approach makes sense and I will follow that when implementing the solution. I'll let you know how things are coming along and when I need help. Thanks!

@vitaliset
Contributor Author

I edited the last comment as the test I suggested does not make sense: when training with sample_weight, the model will be different from the one trained without it... So the leaf values for each model should be different, and even though the aggregate sum is averaging out according to sample_weight, the values we get will be different per se.

Nonetheless, some small tests should be adapted once the fix is done. For instance:

# TODO: remove/fix when PDP supports HGBT with sample weights

# TODO: extend to HistGradientBoosting once sample_weight is supported

@Andrew-Wang-IB45
Contributor

Hi @vitaliset, I thought about an approach to implement your solution and would like some clarifications.

I noticed for the BaseDecisionTree class, sample_weight is used when constructing the inner trees, as follows:

builder.build(self.tree_, X, y, sample_weight)

My understanding is that for the BaseHistGradientBoosting class, sample_weight is correctly used in the loss function and in the sampling processes for the training and validation data, so we just need to add the sample weight as a parameter to the following instantiation of the TreeGrower:

grower = TreeGrower(
    X_binned=X_binned_train,
    gradients=g_view[:, k],
    hessians=h_view[:, k],
    n_bins=n_bins,
    n_bins_non_missing=self._bin_mapper.n_bins_non_missing_,
    has_missing_values=has_missing_values,
    is_categorical=self.is_categorical_,
    monotonic_cst=monotonic_cst,
    interaction_cst=interaction_cst,
    max_leaf_nodes=self.max_leaf_nodes,
    max_depth=self.max_depth,
    min_samples_leaf=self.min_samples_leaf,
    l2_regularization=self.l2_regularization,
    shrinkage=self.learning_rate,
    n_threads=n_threads,
)

Within the TreeGrower, we should be using sample_weight to compute weighted_n_node_samples, storing that as an attribute for each TreeNode. Currently, the only reference I have to computing weighted_n_node_samples is for the TreeBuilder class used in the BaseDecisionTree, where self.weighted_n_node_samples is used in the PDP calculations:

self.sample_weight = sample_weight
self.sample_indices = sample_indices
self.start = start
self.end = end
self.n_node_samples = end - start
self.weighted_n_samples = weighted_n_samples
self.weighted_n_node_samples = 0.0

cdef SIZE_t i
cdef SIZE_t p
cdef SIZE_t k
cdef SIZE_t c
cdef DOUBLE_t w = 1.0

for k in range(self.n_outputs):
    memset(&self.sum_total[k, 0], 0, self.n_classes[k] * sizeof(double))

for p in range(start, end):
    i = sample_indices[p]
    # w is originally set to be 1.0, meaning that if no sample weights
    # are given, the default weight of each sample is 1.0.
    if sample_weight is not None:
        w = sample_weight[i]
    # Count weighted class frequency for each target
    for k in range(self.n_outputs):
        c = <SIZE_t> self.y[i, k]
        self.sum_total[k, c] += w
    self.weighted_n_node_samples += w

The above calculation should be the one used to compute weighted_n_node_samples for BaseHistGradientBoosting instances that are fitted with sample_weight, correct? If so, I would incorporate this calculation in _initialize_root and split_next, specifically after the instantiation of each TreeNode, and save that value as an attribute.
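If that reading is right, the per-node quantity itself reduces to a sum over the node's samples. A minimal NumPy sketch (sample_indices here is a hypothetical stand-in for the indices a grower TreeNode holds, not the real attribute layout):

```python
import numpy as np

sample_weight = np.array([0.5, 1.0, 2.0, 1.5, 3.0])

# Hypothetical indices of the samples that landed in one tree node.
sample_indices = np.array([0, 2, 4])

# The weighted count of the node is just the sum of the weights of its
# samples; with unit weights it degenerates to the plain sample count.
weighted_n_node_samples = sample_weight[sample_indices].sum()
print(weighted_n_node_samples)  # 0.5 + 2.0 + 3.0 = 5.5
```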

Once this is done, then we can simply change all references of count to weighted_n_node_samples in

left_sample_frac = (
    <Y_DTYPE_C> nodes[current_node.left].count /
    current_node.count)
and then remove the NotImplementedError from the recursive version of compute_partial_dependence since this is now supported.
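On a toy stump the change is easy to visualize. The node layout below is hypothetical (plain dicts rather than the real node_struct), but the fraction computed mirrors left_sample_frac: recursion descends both children when a grid point does not decide the split, weighting each child's contribution by its sample fraction.

```python
# Toy stump: root splits into two leaves with these raw counts and
# (hypothetical) weighted counts.
left = {"value": -1.0, "count": 10, "weighted_count": 30.0}
right = {"value": 2.0, "count": 10, "weighted_count": 70.0}

root_count = left["count"] + right["count"]
root_weighted = left["weighted_count"] + right["weighted_count"]

# Unweighted recursion: each child weighted by its raw sample fraction.
pd_unweighted = (left["count"] / root_count) * left["value"] + \
                (right["count"] / root_count) * right["value"]

# Weighted recursion: same formula with weighted counts instead.
pd_weighted = (left["weighted_count"] / root_weighted) * left["value"] + \
              (right["weighted_count"] / root_weighted) * right["value"]

print(pd_unweighted, pd_weighted)  # 0.5 vs 1.1
```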

What are your thoughts on my approach? If this works, I was thinking we should split this into two pull requests, where the first one changes the underlying TreeNode to incorporate sample_weight and the second one changes the PDP calculations and public interface. Clearly, we should benchmark performance and memory footprint upon updating the BaseHistGradientBoosting class. What tools would you recommend for this?

Thanks for your time!

@vitaliset
Contributor Author
vitaliset commented Jan 10, 2023

From my shallow understanding of BaseHistGradientBoosting implementation, this looks about right! Awesome job unraveling the details! :D

For benchmarking, scikit-learn uses airspeed velocity to monitor performance, and I think this PR should run the benchmarks for the HistGradientBoosting from there. You can follow this tutorial on how to use it to compare your branch versus main branch.

But I also think that some timeit.timeit runs should be useful here. Something like:

from timeit import timeit
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification

def fit_predict_hist(n_samples):
    X, y = make_classification(n_samples=n_samples)
    hgbc = HistGradientBoostingClassifier(random_state=42).fit(X, y)
    probs = hgbc.predict_proba(X)

n_samples_list = [100, 1000, 5000]
times = [timeit(lambda: fit_predict_hist(n_samples), number=10) for n_samples in n_samples_list]
print(times)
>>> [2.026070300000015, 12.3248289, 15.59315509999999]

This is just a draft, and you can certainly:

  • change the make_classification to other sklearn.datasets generators (or add them to the function, so it runs many different types - including datasets with categorical data);
  • add something for regression;
  • change the parameters such as number from timeit.timeit and n_samples_list.

You can save it and then plot the main vs your branch to compare the change:

from sklearn import __version__
import joblib
joblib.dump(times, f"time_list_{__version__}")

Update: actually, using timeit.repeat might be better for having a sense of variance:

import numpy as np
from timeit import repeat
...
times = [repeat(lambda: fit_predict_hist(n_samples), repeat=5, number=10) for n_samples in n_samples_list]
np.array(times).mean(axis=1)
>>> array([ 1.92954972, 11.46854726, 13.89775882])
np.array(times).std(axis=1)
>>> array([0.01262523, 0.30100737, 0.16548444])

@Andrew-Wang-IB45
Contributor

Hi, thanks for your feedback and suggestions! I will get to it and keep you updated.

@Andrew-Wang-IB45
Contributor

Hi @vitaliset, I have updated the BaseHistGradientBoosting fit function and the Grower class to keep track of weighted_n_node_samples using the sample_weight passed in. I noticed that for the BaseDecisionTree, one of its attributes comes from the parameter min_weight_fraction_leaf, which is used to calculate min_weight_leaf, which in turn helps determine whether a node is a leaf:

self.min_weight_fraction_leaf = min_weight_fraction_leaf

# Set min_weight_leaf from min_weight_fraction_leaf
if sample_weight is None:
    min_weight_leaf = self.min_weight_fraction_leaf * n_samples
else:
    min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight)

is_leaf = (depth >= max_depth or
           n_node_samples < min_samples_split or
           n_node_samples < 2 * min_samples_leaf or
           weighted_n_node_samples < 2 * min_weight_leaf)

While I think it is good to include this in the tree for the Grower, doing so would require changing the signature of the init function for BaseHistGradientBoosting. I am wondering if we should change it now for completeness or leave it as is to avoid changing the public interface too much.

Regardless, I will do some tests to check for its correctness and performance. Do you have a small example of the expected outputs of the BaseHistGradientBoosting with and without sample weights?

@vitaliset
Contributor Author

Hello @Andrew-Wang-IB45! If BaseDecisionTree uses it on its init, I would not be opposed to doing the same in BaseHistGradientBoosting following the same logic. Feel free to choose one, and during review time, maintainers will give their opinion.

Regarding tests, I think you can follow the logic here to see if it is doing what it is supposed to do:

import pandas as pd
import numpy as np
from sklearn import __version__

print("sklearn version:", __version__)
>>> sklearn version: 1.1.3

from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=1000, weights=(0.9,), random_state=42)
sample_weight = np.random.RandomState(42).exponential(scale=1 + 4*y)

hgbc = (
    HistGradientBoostingClassifier(random_state=42, max_depth=1, max_iter=1)
    .fit(X, y, sample_weight)
)

pd.DataFrame(hgbc._predictors[0][0].nodes)

[image: DataFrame view of the predictor's nodes]

aux = pd.DataFrame(hgbc._predictors[0][0].nodes).loc[0]

feat_idx = int(aux.feature_idx)
num_thresh = aux.num_threshold

sample_weight[X[:, feat_idx] >= num_thresh].sum(), sample_weight[X[:, feat_idx] < num_thresh].sum()
>>> (456.6039224516335, 897.213567789653)

These numbers should be the weighted_n_node_samples of each leaf (note that max_depth=1 and we have only one predictor).
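A related sanity check can be phrased without fitting anything: whatever the split, the two leaf sums must add back up to the total weight. A sketch of that invariant (the feature index and threshold below are made-up values, not taken from a fitted model):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))
sample_weight = rng.exponential(scale=1.0, size=1000)

# Hypothetical stump split: feature 2 at threshold 0.1.
feat_idx, num_thresh = 2, 0.1
left_mask = X[:, feat_idx] < num_thresh

left_sum = sample_weight[left_mask].sum()
right_sum = sample_weight[~left_mask].sum()

# The two leaf sums must add up to the root's total weight.
print(left_sum, right_sum, left_sum + right_sum)
```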

@Andrew-Wang-IB45
Contributor

Hi @vitaliset, as you can see from my draft pull request, I have added support for using the user-inputted sample weights to compute and keep track of the weighted_n_node_samples when growing the tree. Currently, I am attempting to add some tests to reflect these changes, specifically in test_grower.py and possibly test_predictor.py and test_gradient_boosting.py, but I am struggling with what parameters to use for sample_weight and min_weight_leaf. Any assistance would be highly appreciated. In the meantime, I am working on documenting the performance and memory footprint of the updated HistGradientBoosting model.

@vitaliset
Contributor Author
vitaliset commented Feb 5, 2023

Hello @Andrew-Wang-IB45, the first test I would implement would be similar to the code I gave here. I don't know if I was clear, but you can, for instance, assert that the number you get by summing the sample_weight values in a leaf is equal to the weighted_n_node_samples "attribute" you implemented. For that, you can do the calculation "outside" (using the code I wrote earlier, for example):

sample_weight[X[:, feat_idx] >= num_thresh].sum(), sample_weight[X[:, feat_idx] < num_thresh].sum()
>>> (456.6039224516335, 897.213567789653)

In this example, one leaf should have 456 as weighted_n_node_samples, and the other should have 897. This logic can be executed for multiple datasets and randomly picked sample_weight.

Another interesting test might be having only two unique values for X (and then we know which sample went where), for instance, X = np.array(10*[0] + 10*[1]) (and y = X) with sample_weight = 10*[3] + 10*[7]. All your trees should have 30 and 70 for their leaves' weighted_n_node_samples and 100 for their root's. Also, when sample_weight = 20*[1], you can assert that the root's weighted_n_node_samples == 20 (and each leaf's == 10); when sample_weight = 20*[0.5], the root's == 10 (and each leaf's == 5), etc.
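The expected values in that two-value test can be precomputed directly from the weights. A sketch of the arithmetic only (no tree is fitted here, since weighted_n_node_samples on the grower is the attribute under development):

```python
import numpy as np

# The dataset suggested above: two unique feature values with known weights.
X = np.array(10 * [0] + 10 * [1], dtype=float).reshape(-1, 1)
y = X.ravel()
sample_weight = np.array(10 * [3.0] + 10 * [7.0])

# Expected per-leaf weighted counts if the tree separates the two groups:
expected_left = sample_weight[X.ravel() == 0].sum()   # 30.0
expected_right = sample_weight[X.ravel() == 1].sum()  # 70.0
expected_root = sample_weight.sum()                   # 100.0

print(expected_left, expected_right, expected_root)
```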

Once you are comfortable with the code, even if you think you still need extra tests, ask a core contributor to review the PR, and they will suggest additional tests! :D

@Andrew-Wang-IB45
Contributor

Hi @vitaliset, thanks for your suggestions! I have incorporated them into my branch, cleaned it up, and made it available for review. Out of curiosity, how do I request a reviewer for a pull request?

@Andrew-Wang-IB45
Contributor

Hi @vitaliset. The pull request to incorporate weighted_n_node_samples into the tree is currently on hold due to significant performance degradation. I have tried different methods to speed up the code, but unfortunately they are still too slow. Do you have any suggestions for improving the performance, or for finding another, less costly approach that achieves the same effect?

@vitaliset
Contributor Author

Hi @Andrew-Wang-IB45. Simplifying and taking a smaller step might be a way. I suggested on your PR not to do extra calculations beyond the essential ones to store the weighted_n_node_samples. I hope this works!

@Andrew-Wang-IB45
Contributor

Hi @vitaliset, thanks for your suggestions. I am currently removing the extraneous code so that it only computes the weighted_n_node_samples along the way. I will let you know how that goes.

@lorentzenchr
Member
lorentzenchr commented Apr 8, 2023

The conclusion of #25431 is that it is too costly to add sample weights by default. I opened #26128 to discuss a possible way forward.
