From c8f00fc0e0dc980520f32fd306e80fc43e6f2f3c Mon Sep 17 00:00:00 2001 From: Michael Michaelides Date: Sun, 24 Jan 2016 20:29:26 +0000 Subject: [PATCH 1/2] Extend MDS to out-of-sample points. For Euclidean metric this is equivalent to PCA transfor m. The method introduces errors as new points are projected, compared to a new projection o f all points. See Bengio, Yoshua, et al. "Out-of-sample extensions for lle, isomap, mds, eigenmaps, and spectral clustering." Advances in neural information processing systems 16 (2004): 177-184. --- sklearn/manifold/mds.py | 151 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 8 deletions(-) diff --git a/sklearn/manifold/mds.py b/sklearn/manifold/mds.py index 0ab8bd68ea9af..3d8cdbfec6b1c 100644 --- a/sklearn/manifold/mds.py +++ b/sklearn/manifold/mds.py @@ -6,6 +6,7 @@ # Licence: BSD import numpy as np +from numpy import linalg as la import warnings @@ -341,10 +342,15 @@ class MDS(BaseEstimator): "Multidimensional scaling by optimizing goodness of fit to a nonmetric hypothesis" Kruskal, J. Psychometrika, 29, (1964) + "Out-of-sample extensions for lle, isomap, mds, eigenmaps, and spectral + clustering." Bengio, Y. et al. Advances in neural information processing + systems 16 (2004): 177-184. + """ def __init__(self, n_components=2, metric=True, n_init=4, max_iter=300, verbose=0, eps=1e-3, n_jobs=1, - random_state=None, dissimilarity="euclidean"): + random_state=None, dissimilarity="euclidean", + extendible=False): self.n_components = n_components self.dissimilarity = dissimilarity self.metric = metric @@ -354,43 +360,123 @@ def __init__(self, n_components=2, metric=True, n_init=4, self.verbose = verbose self.n_jobs = n_jobs self.random_state = random_state + self.extendible = extendible @property def _pairwise(self): return self.kernel == "precomputed" def fit(self, X, y=None, init=None): - """ - Computes the position of the points in the embedding space + """Fit the model on X Parameters ---------- X : array, shape=[n_samples, n_features], or [n_samples, n_samples] \ if dissimilarity='precomputed' - Input data. init : {None or ndarray, shape (n_samples,)}, optional If None, randomly chooses the initial configuration if ndarray, initialize the SMACOF algorithm with this array. """ - self.fit_transform(X, init=init) + if not self.extendible: + self.fit_transform(X, init=init) + else: + if self.dissimilarity == 'precomputed': + D = X + elif self.dissimilarity == 'euclidean': + D = euclidean_distances(X) + else: + raise ValueError("Proximity must be 'precomputed' or" + "'euclidean'." + " Got %s instead" % str(self.dissimilarity)) + return self + + # Normalising similarities + K = np.zeros(np.shape(D)) + n = len(D) + D_sq = np.square(D) + P = np.eye(n) - 1 / n * np.ones((n, n)) + K = -0.5 * np.dot(np.dot(P, D_sq), P) + + # Sorting e-vectors and e-values according to e-val + e_vals, e_vecs = la.eigh(K) + ind_sort = np.argsort(e_vals)[::-1] + self.e_vecs = e_vecs[:, ind_sort] + self.e_vals = e_vals[ind_sort] return self def fit_transform(self, X, y=None, init=None): - """ - Fit the data from X, and returns the embedded coordinates + """Fit the data from X, and returns the embedded coordinates Parameters ---------- X : array, shape=[n_samples, n_features], or [n_samples, n_samples] \ if dissimilarity='precomputed' - Input data. init : {None or ndarray, shape (n_samples,)}, optional + Should only be used with the non-extendible MDS (SMACOF). If None, randomly chooses the initial configuration if ndarray, initialize the SMACOF algorithm with this array. """ + if self.extendible: + if init is not None: + raise ValueError("Init is only for the non-extendible MDS.") + ret = self._fit_transform_ext(X) + else: + ret = self._fit_transform(X, init) + return ret + + def transform(self, X, X_train=None): + """Apply the transformation on X + + If dissimilarity is Euclidean, apply the transformation on X. + If dissimilarity is precomputed, X is the similarity matrix to be used + between new (out-of-sample) points with old ones. + The new points (X if Euclidean, or with X similarity matrix if + precomputed) are projected in the same space as the training set. + + Parameters + ---------- + X : array, shape [n_samples, n_features], or \ + [n_samples, n_train_samples] if dissimilarity='precomputed' + New data, where n_samples is the number of samples + and n_features is the number of features for "euclidean" + dissimilarity. Else, similarity matrix (e.g. Euclidean distances + between new and training points). + + NB: similarity matrix has to be centered, use the + make_euclidean_similarities function to create it. + + X_train : array, shape [n_train_samples, n_features] \ + if dissimilarity='euclidean' + Training data for Euclidean case. + + Returns + ------- + X_new : array-like, shape (n_samples, n_components) + + """ + + if not self.extendible: + raise ValueError("Method only available if extendible is True.") + if self.dissimilarity == 'precomputed': + D_new = X + elif self.dissimilarity == 'euclidean': + if X_train is None: + raise ValueError("Euclidean requires X_train, the training " + "points.") + else: + D_aX = euclidean_distances(X, X_train) + D_XX = euclidean_distances(X_train, X_train) + D_new = self.center_similarities(D_aX, D_XX) + else: + raise ValueError("Dissimilarity not set properly: 'precomputed' " + "and 'euclidean' allowed.") + X_new = self._mds_project(D_new, k=self.n_components) + return X_new + + def _fit_transform(self, X, init=None): X = check_array(X) if X.shape[0] == X.shape[1] and self.dissimilarity != "precomputed": warnings.warn("The MDS API has changed. ``fit`` now constructs an" @@ -414,3 +500,52 @@ def fit_transform(self, X, y=None, init=None): return_n_iter=True) return self.embedding_ + + def _fit_transform_ext(self, X): + self.fit(X) + X_new = np.dot(self.e_vecs[:, :self.n_components], + np.diag(np.sqrt(self.e_vals[:self.n_components]))) + return X_new + + def center_similarities(self, D_aX, D_XX): + """Centers similarities D_aX around D_XX + + Parameters + ---------- + D_aX : array, shape=[n_new_samples, n_train_samples] + Dissimilarity matrix of new and training data. + D_XX : array, shape=[n_train_samples, n_train_samples] + Dissimilarity matrix of training data. + + Returns + ------- + new_similarities : array-like, shape=[n_new_samples, n_train_samples] + """ + D_aX = np.square(D_aX) + D_XX = np.square(D_XX) + N = len(D_XX) + M = len(D_aX) + I_NN = np.ones((N, N)) + I_MN = np.ones((M, N)) + Exp_XX = np.sum(D_XX) / N ** 2 + + new_similarities = -0.5 * (D_aX - (np.dot(D_aX, I_NN) + + np.dot(I_MN, D_XX) + ) / N + Exp_XX) + return new_similarities + + def _mds_project(self, new_similarities, k=None): + if k is None: + k = self.n_components + + e_projections = np.zeros((len(new_similarities), k)) + + for i in range(len(new_similarities)): + for j in range(k): + e_projections[i, j] = ((np.dot(self.e_vecs[:, j], + new_similarities[i]) / + np.sqrt(self.e_vals[j]))) + e_projections = np.dot(new_similarities, + np.dot(self.e_vecs[:, :k], + np.diag(1/np.sqrt(self.e_vals[:k])))) + return e_projections From aab1d93e662b1ece2746eacf888734b37d11f043 Mon Sep 17 00:00:00 2001 From: Michael Michaelides Date: Sun, 24 Jan 2016 20:39:16 +0000 Subject: [PATCH 2/2] Added test for extendible MDS. --- sklearn/manifold/tests/test_mds.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/manifold/tests/test_mds.py b/sklearn/manifold/tests/test_mds.py index a078afd8ca21a..96fb2379062a7 100644 --- a/sklearn/manifold/tests/test_mds.py +++ b/sklearn/manifold/tests/test_mds.py @@ -59,3 +59,9 @@ def test_MDS(): [4, 2, 1, 0]]) mds_clf = mds.MDS(metric=False, n_jobs=3, dissimilarity="precomputed") mds_clf.fit(sim) + # Testing for extendible MDS + mds_clf = mds.MDS(dissimilarity="euclidean", extendible=True) + mds_clf.fit(sim) + sim2 = np.array([[3, 1, 1, 2], + [4, 1, 2, 2]]) + mds_clf.transform(sim2, sim)