@@ -1129,6 +1129,24 @@ def _check_chunk_size(reduced, chunk_size):
1129
1129
chunk_size ))
1130
1130
1131
1131
1132
+ def _precompute_metric_params (X , Y , metric = None , ** kwds ):
1133
+ """Precompute data-derived metric parameters if not provided
1134
+ """
1135
+ if metric == "seuclidean" and 'V' not in kwds :
1136
+ if X is Y :
1137
+ V = np .var (X , axis = 0 , ddof = 1 )
1138
+ else :
1139
+ V = np .var (np .vstack ([X , Y ]), axis = 0 , ddof = 1 )
1140
+ return {'V' : V }
1141
+ if metric == "mahalanobis" and 'VI' not in kwds :
1142
+ if X is Y :
1143
+ VI = np .linalg .inv (np .cov (X .T )).T
1144
+ else :
1145
+ VI = np .linalg .inv (np .cov (np .vstack ([X , Y ]).T )).T
1146
+ return {'VI' : VI }
1147
+ return {}
1148
+
1149
+
1132
1150
def pairwise_distances_chunked (X , Y = None , reduce_func = None ,
1133
1151
metric = 'euclidean' , n_jobs = None ,
1134
1152
working_memory = None , ** kwds ):
@@ -1264,18 +1282,9 @@ def pairwise_distances_chunked(X, Y=None, reduce_func=None,
1264
1282
working_memory = working_memory )
1265
1283
slices = gen_batches (n_samples_X , chunk_n_rows )
1266
1284
1267
- if metric == "seuclidean" and 'V' not in kwds :
1268
- if X is Y :
1269
- V = np .var (X , axis = 0 , ddof = 1 )
1270
- else :
1271
- V = np .var (np .vstack ([X , Y ]), axis = 0 , ddof = 1 )
1272
- kwds .update ({'V' : V })
1273
- elif metric == "mahalanobis" and 'VI' not in kwds :
1274
- if X is Y :
1275
- VI = np .linalg .inv (np .cov (X .T )).T
1276
- else :
1277
- VI = np .linalg .inv (np .cov (np .vstack ([X , Y ]).T )).T
1278
- kwds .update ({'VI' : VI })
1285
+ # precompute data-derived metric params
1286
+ params = _precompute_metric_params (X , Y , metric = metric , ** kwds )
1287
+ kwds .update (** params )
1279
1288
1280
1289
for sl in slices :
1281
1290
if sl .start == 0 and sl .stop == n_samples_X :
@@ -1408,18 +1417,9 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=None, **kwds):
1408
1417
dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else None
1409
1418
X , Y = check_pairwise_arrays (X , Y , dtype = dtype )
1410
1419
1411
- if metric == "seuclidean" and 'V' not in kwds :
1412
- if X is Y :
1413
- V = np .var (X , axis = 0 , ddof = 1 )
1414
- else :
1415
- V = np .var (np .vstack ([X , Y ]), axis = 0 , ddof = 1 )
1416
- kwds .update ({'V' : V })
1417
- elif metric == "mahalanobis" and 'VI' not in kwds :
1418
- if X is Y :
1419
- VI = np .linalg .inv (np .cov (X .T )).T
1420
- else :
1421
- VI = np .linalg .inv (np .cov (np .vstack ([X , Y ]).T )).T
1422
- kwds .update ({'VI' : VI })
1420
+ # precompute data-derived metric params
1421
+ params = _precompute_metric_params (X , Y , metric = metric , ** kwds )
1422
+ kwds .update (** params )
1423
1423
1424
1424
if effective_n_jobs (n_jobs ) == 1 and X is Y :
1425
1425
return distance .squareform (distance .pdist (X , metric = metric ,
0 commit comments