|
5 | 5 | # Thierry Guillemot <thierry.guillemot.work@gmail.com>
|
6 | 6 | # License: BSD 3 clause
|
7 | 7 |
|
8 |
| -from itertools import chain |
9 |
| -from itertools import permutations |
10 |
| - |
11 | 8 | import numpy as np
|
12 | 9 |
|
13 |
| -from ...externals.six.moves import xrange |
14 | 10 | from ...utils import check_array
|
15 | 11 | from ...utils import check_consistent_length
|
16 | 12 | from ...utils import check_random_state
|
17 | 13 | from ...utils import check_X_y
|
18 | 14 | from ...utils.fixes import bincount
|
19 | 15 | from ..pairwise import pairwise_distances
|
20 | 16 | from ...preprocessing import LabelEncoder
|
| 17 | +from .supervised import contingency_matrix |
21 | 18 |
|
22 | 19 |
|
23 | 20 | def check_number_of_labels(n_labels, n_samples):
|
@@ -304,27 +301,15 @@ def prediction_strength_score(labels_train, labels_test):
|
304 | 301 | labels_test = check_array(labels_test, dtype=np.int32, ensure_2d=False,
|
305 | 302 | warn_on_dtype=True)
|
306 | 303 |
|
307 |
| - clusters = set(chain(labels_train, labels_test)) |
308 |
| - n_clusters = len(clusters) |
| 304 | + n_clusters = max(np.unique(labels_train).shape[0], |
| 305 | + np.unique(labels_test).shape[0]) |
309 | 306 | if n_clusters == 1:
|
310 | 307 | return 1.0 # by definition
|
311 | 308 |
|
312 |
| - strength = 1.0 |
313 |
| - for k in clusters: |
314 |
| - # samples assigned to k-th cluster based on test data |
315 |
| - samples_test_k = np.flatnonzero(labels_test == k) |
316 |
| - cluster_test_size = samples_test_k.shape[0] |
317 |
| - |
318 |
| - if cluster_test_size < 2: |
319 |
| - continue |
320 |
| - |
321 |
| - matches = 0 |
322 |
| - for i, j in permutations(xrange(cluster_test_size), 2): |
323 |
| - ki, kj = samples_test_k[j], samples_test_k[i] |
324 |
| - if labels_train[ki] == labels_train[kj]: |
325 |
| - matches += 1 |
326 |
| - |
327 |
| - strength = min(strength, matches / (cluster_test_size * |
328 |
| - (cluster_test_size - 1.))) |
| 309 | + C = contingency_matrix(labels_train, labels_test) |
| 310 | + pairs_matching = (C * (C - 1) / 2).sum(axis=0) |
| 311 | + M = C.sum(axis=0) |
| 312 | + pairs_total = (M * (M - 1) / 2) |
| 313 | + nz = pairs_total.nonzero()[0] |
329 | 314 |
|
330 |
| - return strength |
| 315 | + return (pairs_matching[nz] / pairs_total[nz]).min() |
0 commit comments