File tree 1 file changed +11
-3
lines changed 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -260,6 +260,13 @@ def calinski_harabaz_score(X, labels):
260
260
(intra_disp * (n_labels - 1. )))
261
261
262
262
263
def _non_zero_add(sparse_matrix, value):
    """Add ``value`` to every explicitly stored entry of a sparse matrix.

    Only the entries held in the matrix's ``data`` array are shifted (this
    includes explicitly stored zeros, if there are any); implicit zeros are
    left untouched, so the sparsity pattern is preserved.

    Parameters
    ----------
    sparse_matrix : sparse matrix exposing a ``data`` array
        Input matrix. It is copied and never modified.

    value : scalar
        Amount added to each stored entry.

    Returns
    -------
    M : sparse matrix
        A copy of ``sparse_matrix`` whose stored entries were shifted by
        ``value``.
    """
    M = sparse_matrix.copy()
    # Rebind rather than use ``+=``: an in-place add cannot upcast, so adding
    # a float to integer-valued data would raise a casting error.
    M.data = M.data + value
    return M
268
+
269
+
263
270
def prediction_strength_score (labels_train , labels_test ):
264
271
"""Compute the prediction strength score.
265
272
@@ -306,9 +313,10 @@ def prediction_strength_score(labels_train, labels_test):
306
313
if n_clusters == 1 :
307
314
return 1.0 # by definition
308
315
309
- C = contingency_matrix (labels_train , labels_test )
310
- pairs_matching = (C * (C - 1 ) / 2 ).sum (axis = 0 )
311
- M = C .sum (axis = 0 )
316
+ C = contingency_matrix (labels_train , labels_test , sparse = True )
317
+ Cp = C .multiply (_non_zero_add (C , - 1 )) / 2
318
+ pairs_matching = np .asarray (Cp .sum (axis = 0 )).ravel ()
319
+ M = np .asarray (C .sum (axis = 0 )).ravel ()
312
320
pairs_total = (M * (M - 1 ) / 2 ).astype (np .float_ )
313
321
nz = pairs_total .nonzero ()[0 ]
314
322
You can’t perform that action at this time.
0 commit comments