[MRG] GBDT: Reuse allocated memory of other histograms by NicolasHug · Pull Request #14392 · scikit-learn/scikit-learn · GitHub

[MRG] GBDT: Reuse allocated memory of other histograms #14392


Closed
29 changes: 17 additions & 12 deletions doc/whats_new/v0.24.rst
@@ -188,19 +188,24 @@ Changelog
   method `staged_predict`, which allows monitoring of each stage.
   :pr:`16985` by :user:`Hao Chun Chang <haochunchang>`.

-- |Efficiency| break cyclic references in the tree nodes used internally in
+- |Efficiency| Various improvements were made to
   :class:`ensemble.HistGradientBoostingRegressor` and
-  :class:`ensemble.HistGradientBoostingClassifier` to allow for the timely
-  garbage collection of large intermediate datastructures and to improve memory
-  usage in `fit`. :pr:`18334` by `Olivier Grisel`_ `Nicolas Hug`_, `Thomas
-  Fan`_ and `Andreas Müller`_.
-
-- |Efficiency| Histogram initialization is now done in parallel in
-  :class:`ensemble.HistGradientBoostingRegressor` and
-  :class:`ensemble.HistGradientBoostingClassifier` which results in speed
-  improvement for problems that build a lot of nodes on multicore machines.
-  :pr:`18341` by `Olivier Grisel`_, `Nicolas Hug`_, `Thomas Fan`_, and
-  :user:`Egor Smirnov <SmirnovEgorRu>`.
+  :class:`ensemble.HistGradientBoostingClassifier` which lead to less memory
+  usage, as well as faster training times:
+
+  - break cyclic references in the tree nodes used internally to allow for
+    the timely garbage collection of large intermediate datastructures and to
+    improve memory usage in `fit`. :pr:`18334` by `Olivier Grisel`_ `Nicolas
+    Hug`_, `Thomas Fan`_ and `Andreas Müller`_.
+
+  - Histogram initialization is now done in parallel which results in speed
+    improvement on multicore machines, for problems that build a lot of nodes.
+    :pr:`18341` by `Olivier Grisel`_, `Nicolas Hug`_, `Thomas Fan`_, and
+    :user:`Egor Smirnov <SmirnovEgorRu>`.
+
+  - Allocated histograms can be reused by other nodes of the same tree,
+    leading to fewer memory allocations. :pr:`14392` by `Olivier Grisel`_,
+    `Nicolas Hug`_, `Thomas Fan`_.

 - |API|: The parameter ``n_classes_`` is now deprecated in
   :class:`ensemble.GradientBoostingRegressor` and returns `1`.
2 changes: 2 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/grower.py
@@ -485,10 +485,12 @@ def split_next(self):
         # for leaf nodes since they won't be split.
         for child in (left_child_node, right_child_node):
             if child.is_leaf:
+                self.histogram_builder.release(child.histograms)
                 del child.histograms

         # Release memory used by histograms as they are no longer needed for
         # internal nodes once children histograms have been computed.
+        self.histogram_builder.release(node.histograms)
         del node.histograms

         return left_child_node, right_child_node
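
To see why releasing histograms in `split_next` can reduce allocations, here is a minimal pure-Python simulation. It is only a sketch: `CountingPool`, `grow`, and the 16-byte `bytearray` buffers are hypothetical stand-ins, not scikit-learn code, and the splitting order is simplified to first-in, first-out.

```python
class CountingPool:
    """Hypothetical buffer pool that counts how many buffers it allocates."""

    def __init__(self):
        self.available = []      # buffers released by finished nodes
        self.n_allocations = 0

    def acquire(self):
        # Reuse a released buffer if one exists, otherwise allocate.
        if self.available:
            return self.available.pop()
        self.n_allocations += 1
        return bytearray(16)     # stand-in for a histograms array

    def release(self, buf):
        self.available.append(buf)


def grow(n_splits, pool):
    """Split pending nodes n_splits times and return the allocation count."""
    pending = [pool.acquire()]   # histograms of the root node
    for _ in range(n_splits):
        parent = pending.pop(0)
        left, right = pool.acquire(), pool.acquire()
        # The parent's histograms are no longer needed once both
        # children have theirs, so hand the buffer back to the pool.
        pool.release(parent)
        pending.extend([left, right])
    return pool.n_allocations


allocations_with_reuse = grow(5, CountingPool())
```

In this simplified model, 5 splits need 7 allocations instead of the 11 (1 + 2 × 5) that a fresh allocation per node would require, because every split after the first can reuse the buffer released by the previous parent.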
31 changes: 23 additions & 8 deletions sklearn/ensemble/_hist_gradient_boosting/histogram.pyx
@@ -85,6 +85,7 @@ cdef class HistogramBuilder:
         G_H_DTYPE_C [::1] ordered_gradients
         G_H_DTYPE_C [::1] ordered_hessians
         unsigned char hessians_are_constant
+        list available_histograms

     def __init__(self, const X_BINNED_DTYPE_C [::1, :] X_binned,
                  unsigned int n_bins, G_H_DTYPE_C [::1] gradients,
@@ -103,6 +104,26 @@ cdef class HistogramBuilder:
         self.ordered_hessians = hessians.copy()
         self.hessians_are_constant = hessians_are_constant

+        # list of histograms that can be re-used for other nodes.
+        self.available_histograms = []
+
+    def allocate_or_reuse_hists(HistogramBuilder self):
+        """Return a non-initialized histograms array.
+
+        The array is allocated only if needed.
+        """
+        if self.available_histograms:
+            return self.available_histograms.pop()
+        else:
+            return np.empty(
+                shape=(self.n_features, self.n_bins),
+                dtype=HISTOGRAM_DTYPE
+            )
+
+    def release(HistogramBuilder self, histograms):
+        """Mark a histograms array as available so it can be reused by other nodes"""
+        self.available_histograms.append(histograms)
+
     def compute_histograms_brute(
             HistogramBuilder self,
             const unsigned int [::1] sample_indices):  # IN
@@ -133,10 +154,7 @@ cdef class HistogramBuilder:
             G_H_DTYPE_C [::1] ordered_hessians = self.ordered_hessians
             G_H_DTYPE_C [::1] hessians = self.hessians
             # Histograms will be initialized to zero later within a prange
-            hist_struct [:, ::1] histograms = np.empty(
-                shape=(self.n_features, self.n_bins),
-                dtype=HISTOGRAM_DTYPE
-            )
+            hist_struct [:, ::1] histograms = self.allocate_or_reuse_hists()

         with nogil:
             n_samples = sample_indices.shape[0]
@@ -234,10 +252,7 @@ cdef class HistogramBuilder:
         cdef:
             int feature_idx
             int n_features = self.n_features
-            hist_struct [:, ::1] histograms = np.empty(
-                shape=(self.n_features, self.n_bins),
-                dtype=HISTOGRAM_DTYPE
-            )
+            hist_struct [:, ::1] histograms = self.allocate_or_reuse_hists()

         for feature_idx in prange(n_features, schedule='static', nogil=True):
             # Compute histogram of each feature
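
The allocate-or-reuse pattern in this file can be sketched in plain Python as follows. This is an illustration under stated assumptions rather than the PR's implementation: `HistogramPool` and `zero` are hypothetical names, and a `bytearray` stands in for the `np.empty` array of `HISTOGRAM_DTYPE`. The point it makes is that a reused buffer still holds the previous node's values, which is why the caller has to overwrite or zero it before accumulating.

```python
class HistogramPool:
    """Hypothetical stand-in for the pooling added to HistogramBuilder."""

    def __init__(self, n_features, n_bins):
        self.size = n_features * n_bins
        self.available = []          # buffers released by other nodes

    def allocate_or_reuse(self):
        # A reused buffer is returned as-is, so its contents are stale.
        if self.available:
            return self.available.pop()
        return bytearray(self.size)

    def release(self, buf):
        self.available.append(buf)


def zero(buf):
    """Reset a buffer in place before it is used for a new node."""
    buf[:] = bytes(len(buf))


pool = HistogramPool(n_features=2, n_bins=4)
first = pool.allocate_or_reuse()
first[0] = 7                         # pretend a node accumulated into it
pool.release(first)

second = pool.allocate_or_reuse()    # same object as `first`
stale_value = second[0]              # leftover from the previous node
zero(second)
```

Popping from the end of the list keeps both `release` and `allocate_or_reuse` constant-time. The pool never shrinks on its own, so its memory is only freed when the pool itself is discarded.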