|
6 | 6 | from scipy.special import digamma
|
7 | 7 |
|
8 | 8 | from ..metrics.cluster import mutual_info_score
|
9 |
| -from ..neighbors import NearestNeighbors |
| 9 | +from ..neighbors import NearestNeighbors, KDTree |
10 | 10 | from ..preprocessing import scale
|
11 | 11 | from ..utils import check_random_state
|
12 | 12 | from ..utils.fixes import _astype_copy_false
|
@@ -58,17 +58,15 @@ def _compute_mi_cc(x, y, n_neighbors):
|
58 | 58 | radius = nn.kneighbors()[0]
|
59 | 59 | radius = np.nextafter(radius[:, -1], 0)
|
60 | 60 |
|
61 |
| - # Algorithm is selected explicitly to allow passing an array as radius |
62 |
| - # later (not all algorithms support this). |
63 |
| - nn.set_params(algorithm='kd_tree') |
| 61 | + # KDTree is explicitly fit to allow querying the number of |
| 62 | + # neighbors within a specified radius. |
| 63 | + kd = KDTree(x, metric='chebyshev') |
| 64 | + nx = kd.query_radius(x, radius, count_only=True, return_distance=False) |
| 65 | + nx = np.array(nx) - 1.0 |
64 | 66 |
|
65 |
| - nn.fit(x) |
66 |
| - ind = nn.radius_neighbors(radius=radius, return_distance=False) |
67 |
| - nx = np.array([i.size for i in ind]) |
68 |
| - |
69 |
| - nn.fit(y) |
70 |
| - ind = nn.radius_neighbors(radius=radius, return_distance=False) |
71 |
| - ny = np.array([i.size for i in ind]) |
| 67 | + kd = KDTree(y, metric='chebyshev') |
| 68 | + ny = kd.query_radius(y, radius, count_only=True, return_distance=False) |
| 69 | + ny = np.array(ny) - 1.0 |
72 | 70 |
|
73 | 71 | mi = (digamma(n_samples) + digamma(n_neighbors) -
|
74 | 72 | np.mean(digamma(nx + 1)) - np.mean(digamma(ny + 1)))
|
@@ -135,10 +133,9 @@ def _compute_mi_cd(c, d, n_neighbors):
|
135 | 133 | c = c[mask]
|
136 | 134 | radius = radius[mask]
|
137 | 135 |
|
138 |
| - nn.set_params(algorithm='kd_tree') |
139 |
| - nn.fit(c) |
140 |
| - ind = nn.radius_neighbors(radius=radius, return_distance=False) |
141 |
| - m_all = np.array([i.size for i in ind]) |
| 136 | + kd = KDTree(c) |
| 137 | + m_all = kd.query_radius(c, radius, count_only=True, return_distance=False) |
| 138 | + m_all = np.array(m_all) - 1.0 |
142 | 139 |
|
143 | 140 | mi = (digamma(n_samples) + np.mean(digamma(k_all)) -
|
144 | 141 | np.mean(digamma(label_counts)) -
|
|
0 commit comments