8000 helper to precompute metric params · scikit-learn/scikit-learn@db65a96 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit db65a96

Browse files
committed
helper to precompute metric params
1 parent 5820938 commit db65a96

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

sklearn/metrics/pairwise.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,24 @@ def _check_chunk_size(reduced, chunk_size):
11291129
chunk_size))
11301130

11311131

1132+
def _precompute_metric_params(X, Y, metric=None, **kwds):
1133+
"""Precompute data-derived metric parameters if not provided
1134+
"""
1135+
if metric == "seuclidean" and 'V' not in kwds:
1136+
if X is Y:
1137+
V = np.var(X, axis=0, ddof=1)
1138+
else:
1139+
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
1140+
return {'V': V}
1141+
if metric == "mahalanobis" and 'VI' not in kwds:
1142+
if X is Y:
1143+
VI = np.linalg.inv(np.cov(X.T)).T
1144+
else:
1145+
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
1146+
return {'VI': VI}
1147+
return {}
1148+
1149+
11321150
def pairwise_distances_chunked(X, Y=None, reduce_func=None,
11331151
metric='euclidean', n_jobs=None,
11341152
working_memory=None, **kwds):
@@ -1264,18 +1282,9 @@ def pairwise_distances_chunked(X, Y=None, reduce_func=None,
12641282
working_memory=working_memory)
12651283
slices = gen_batches(n_samples_X, chunk_n_rows)
12661284

1267-
if metric == "seuclidean" and 'V' not in kwds:
1268-
if X is Y:
1269-
V = np.var(X, axis=0, ddof=1)
1270-
else:
1271-
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
1272-
kwds.update({'V': V})
1273-
elif metric == "mahalanobis" and 'VI' not in kwds:
1274-
if X is Y:
1275-
VI = np.linalg.inv(np.cov(X.T)).T
1276-
else:
1277-
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
1278-
kwds.update({'VI': VI})
1285+
# precompute data-derived metric params
1286+
params = _precompute_metric_params(X, Y, metric=metric, **kwds)
1287+
kwds.update(**params)
12791288

12801289
for sl in slices:
12811290
if sl.start == 0 and sl.stop == n_samples_X:
@@ -1408,18 +1417,9 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=None, **kwds):
14081417
dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else None
14091418
X, Y = check_pairwise_arrays(X, Y, dtype=dtype)
14101419

1411-
if metric == "seuclidean" and 'V' not in kwds:
1412-
if X is Y:
1413-
V = np.var(X, axis=0, ddof=1)
1414-
else:
1415-
V = np.var(np.vstack([X, Y]), axis=0, ddof=1)
1416-
kwds.update({'V': V})
1417-
elif metric == "mahalanobis" and 'VI' not in kwds:
1418-
if X is Y:
1419-
VI = np.linalg.inv(np.cov(X.T)).T
1420-
else:
1421-
VI = np.linalg.inv(np.cov(np.vstack([X, Y]).T)).T
1422-
kwds.update({'VI': VI})
1420+
# precompute data-derived metric params
1421+
params = _precompute_metric_params(X, Y, metric=metric, **kwds)
1422+
kwds.update(**params)
14231423

14241424
if effective_n_jobs(n_jobs) == 1 and X is Y:
14251425
return distance.squareform(distance.pdist(X, metric=metric,

0 commit comments

Comments
 (0)
0