8000 PERF Call KDTree in mutual_info to reduce memory footprint (#17878) · thomasjpfan/scikit-learn@0b0afd2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0b0afd2

Browse files
noelano and thomasjpfan authored
PERF Call KDTree in mutual_info to reduce memory footprint (scikit-learn#17878)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 0cfe98b commit 0b0afd2

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ Changelog
147147
:pr:`17090` by :user:`Lisa Schwetlick <lschwetlick>` and
148148
:user:`Marija Vlajic Wheeler <marijavlajic>`.
149149

150+
- |Efficiency| Reduce memory footprint in :func:`feature_selection.mutual_info_classif`
151+
and :func:`feature_selection.mutual_info_regression` by calling
152+
:class:`neighbors.KDTree` for counting nearest neighbors.
153+
:pr:`17878` by :user:`Noel Rogers <noelano>`.
154+
150155
:mod:`sklearn.gaussian_process`
151156
...............................
152157

sklearn/feature_selection/_mutual_info.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.special import digamma
77

88
from ..metrics.cluster import mutual_info_score
9-
from ..neighbors import NearestNeighbors
9+
from ..neighbors import NearestNeighbors, KDTree
1010
from ..preprocessing import scale
1111
from ..utils import check_random_state
1212
from ..utils.fixes import _astype_copy_false
@@ -58,17 +58,15 @@ def _compute_mi_cc(x, y, n_neighbors):
5858
radius = nn.kneighbors()[0]
5959
radius = np.nextafter(radius[:, -1], 0)
6060

61-
# Algorithm is selected explicitly to allow passing an array as radius
62-
# later (not all algorithms support this).
63-
nn.set_params(algorithm='kd_tree')
61+
# KDTree is explicitly fit to allow for the querying of number of
62+
# neighbors within a specified radius
63+
kd = KDTree(x, metric='chebyshev')
64+
nx = kd.query_radius(x, radius, count_only=True, return_distance=False)
65+
nx = np.array(nx) - 1.0
6466

65-
nn.fit(x)
66-
ind = nn.radius_neighbors(radius=radius, return_distance=False)
67-
nx = np.array([i.size for i in ind])
68-
69-
nn.fit(y)
70-
ind = nn.radius_neighbors(radius=radius, return_distance=False)
71-
ny = np.array([i.size for i in ind])
67+
kd = KDTree(y, metric='chebyshev')
68+
ny = kd.query_radius(y, radius, count_only=True, return_distance=False)
69+
ny = np.array(ny) - 1.0
7270

7371
mi = (digamma(n_samples) + digamma(n_neighbors) -
7472
np.mean(digamma(nx + 1)) - np.mean(digamma(ny + 1)))
@@ -135,10 +133,9 @@ def _compute_mi_cd(c, d, n_neighbors):
135133
c = c[mask]
136134
radius = radius[mask]
137135

138-
nn.set_params(algorithm='kd_tree')
139-
nn.fit(c)
140-
ind = nn.radius_neighbors(radius=radius, return_distance=False)
141-
m_all = np.array([i.size for i in ind])
136+
kd = KDTree(c)
137+
m_all = kd.query_radius(c, radius, count_only=True, return_distance=False)
138+
m_all = np.array(m_all) - 1.0
142139

143140
mi = (digamma(n_samples) + np.mean(digamma(k_all)) -
144141
np.mean(digamma(label_counts)) -

0 commit comments

Comments (0)