10000 helper to precompute metric params · scikit-learn/scikit-learn@d56c20f · GitHub
[go: up one dir, main page]

Skip to content

Commit d56c20f

Browse files
committed
helper to precompute metric params
1 parent 03d1a46 commit d56c20f

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

sklearn/metrics/pairwise.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,24 @@ def _check_chunk_size(reduced, chunk_size):
11291129
chunk_size))
11301130

11311131

1132+
def _precompute_metric_params(X, Y, metric=None, **kwds):
    """Precompute data-derived metric parameters if not provided.

    Some metrics need a parameter estimated from the data: the per-feature
    variance ``V`` for ``"seuclidean"`` and the inverse covariance matrix
    ``VI`` for ``"mahalanobis"``. Computing it once here, on the full data,
    lets callers pass the same value to every chunk or worker.

    Parameters
    ----------
    X : ndarray
        First set of samples; rows are treated as samples (statistics are
        taken along ``axis=0``).
    Y : ndarray or None
        Second set of samples. When ``Y`` is ``None`` or is the very same
        object as ``X``, statistics are computed from ``X`` alone; otherwise
        they are computed from ``X`` and ``Y`` stacked row-wise.
    metric : str, optional
        Name of the metric. Only ``"seuclidean"`` and ``"mahalanobis"``
        trigger any computation.
    **kwds
        Metric keyword arguments already supplied by the caller. A parameter
        present here (``'V'`` or ``'VI'``) is never recomputed.

    Returns
    -------
    params : dict
        ``{'V': ...}``, ``{'VI': ...}`` or ``{}`` when nothing has to be
        precomputed. ``kwds`` itself is not modified.
    """
    needs_v = metric == "seuclidean" and 'V' not in kwds
    needs_vi = metric == "mahalanobis" and 'VI' not in kwds
    if not (needs_v or needs_vi):
        return {}

    # Only stack when a statistic is really needed and Y is a distinct array,
    # so the common "nothing to do" path never copies the data.
    if Y is None or X is Y:
        data = X
    else:
        data = np.vstack([X, Y])

    if needs_v:
        return {'V': np.var(data, axis=0, ddof=1)}

    # np.cov collapses to a 0-d array for single-feature data, which
    # np.linalg.inv rejects; promote it to a 1x1 matrix first.
    cov = np.atleast_2d(np.cov(data.T))
    return {'VI': np.linalg.inv(cov).T}
1148+
1149+
11321150
def pairwise_distances_chunked(X, Y=None, reduce_func=None,
11331151
metric='euclidean', n_jobs=None,
11341152
working_memory=None, **kwds):
@@ -1264,18 +1282,9 @@ def pairwise_distances_chunked(X, Y=None, reduce_func=None,
12641282
working_memory=working_memory)
12651283
slices = gen_batches(n_samples_X, chunk_n_rows)
12661284

1267-
if metric == "seuclidean" and 'V' not in kwds:
1268-
if X is Y:
1269-
V = np.var(X, axis=0, ddof=1)
1270-
else:
1271-
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
1272-
kwds.update({'V': V})
1273-
elif metric == "mahalanobis" and 'VI' not in kwds:
1274-
if X is Y:
1275-
VI = np.linalg.inv(np.cov(X.T)).T
1276-
else:
1277-
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
1278-
kwds.update({'VI': VI})
1285+
# precompute data-derived metric params
1286+
params = _precompute_metric_params(X, Y, metric=metric, **kwds)
1287+
kwds.update(**params)
12791288

12801289
for sl in slices:
12811290
if sl.start == 0 and sl.stop == n_samples_X:
@@ -1408,18 +1417,9 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=None, **kwds):
14081417
dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else None
14091418
X, Y = check_pairwise_arrays(X, Y, dtype=dtype)
14101419

1411-
if metric == "seuclidean" and 'V' not in kwds:
1412-
if X is Y:
1413-
V = np.var(X, axis=0, ddof=1)
1414-
else:
1415-
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
1416-
kwds.update({'V': V})
1417-
elif metric == "mahalanobis" and 'VI' not in kwds:
1418-
if X is Y:
1419-
VI = np.linalg.inv(np.cov(X.T)).T
1420-
else:
1421-
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
1422-
kwds.update({'VI': VI})
1420+
# precompute data-derived metric params
1421+
params = _precompute_metric_params(X, Y, metric=metric, **kwds)
1422+
kwds.update(**params)
14231423

14241424
if effective_n_jobs(n_jobs) == 1 and X is Y:
14251425
return distance.squareform(distance.pdist(X, metric=metric,

0 commit comments

Comments
 (0)
0