8000 k-means adding all_paris_l2_distance_squared function · seckcoder/scikit-learn@1daf251 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1daf251

Browse files
jabergGaelVaroquaux
authored andcommitted
k-means adding all_paris_l2_distance_squared function
1 parent bcf8494 commit 1daf251

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

scikits/learn/cluster/k_means_.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,34 @@
1212

1313
from ..base import BaseEstimator
1414

15+
def all_pairs_l2_distance_squared(A, B, B_norm_squared=None):
16+
"""
17+
Returns the squared l2 norms of the differences between rows of A and B.
18+
19+
Parameters
20+
----------
21+
A: array, [n_rows_A, n_cols]
22+
23+
B: array, [n_rows_B, n_cols]
24+
25+
B_norm_squared: array [n_rows_B], or None
26+
pre-computed (B**2).sum(axis=1)
27+
28+
Returns
29+
-------
30+
31+
array [n_rows_A, n_rows_B]
32+
entry [i,j] is ((A[i] - B[i])**2).sum(axis=1)
33+
34+
"""
35+
if B_norm_squared is None:
36+
B_norm_squared = (B**2).sum(axis=1)
37+
if A is B:
38+
A_norm_squared = B_norm_squared
39+
else:
40+
A_norm_squared = (A**2).sum(axis=1)
41+
return (B_norm_squared + A_norm_squared.reshape((A.shape[0],1)) - 2*np.dot(A, B.T))
42+
1543
################################################################################
1644
# Initialisation heuristic
1745

0 commit comments

Comments
 (0)
0