diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index c6838556d50ad..c8f293cfd5e26 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1346,6 +1346,7 @@ Model validation neighbors.NearestCentroid neighbors.NearestNeighbors neighbors.NeighborhoodComponentsAnalysis + neighbors.LargeMarginNearestNeighbor .. autosummary:: :toctree: generated/ diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst index 293f31dacd091..6d8ea0016de92 100644 --- a/doc/modules/decomposition.rst +++ b/doc/modules/decomposition.rst @@ -1085,4 +1085,5 @@ when data can be fetched sequentially. H. F. Kaiser, 1958 See also :ref:`nca_dim_reduction` for dimensionality reduction with -Neighborhood Components Analysis. +Neighborhood Components Analysis or :ref:`lmnn_dim_reduction` for +dimensionality reduction with Large Margin Nearest Neighbor. diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst index b362e5e69f7ee..4b165c230e036 100644 --- a/doc/modules/neighbors.rst +++ b/doc/modules/neighbors.rst @@ -843,3 +843,214 @@ added space complexity in the operation. `Wikipedia entry on Neighborhood Components Analysis `_ + + +.. _lmnn: + +Large Margin Nearest Neighbor +============================= + +.. sectionauthor:: John Chiotellis + +Large Margin Nearest Neighbor (LMNN, :class:`LargeMarginNearestNeighbor`) is +a metric learning algorithm which aims to improve the accuracy of +nearest neighbors classification compared to the standard Euclidean distance. + +.. |lmnn_illustration_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_001.png + :target: ../auto_examples/neighbors/plot_lmnn_illustration.html + :scale: 50 + +.. |lmnn_illustration_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_002.png + :target: ../auto_examples/neighbors/plot_lmnn_illustration.html + :scale: 50 + +.. centered:: |lmnn_illustration_1| |lmnn_illustration_2| + + +For each training sample, the algorithm fixes :math:`k` "target neighbors", +namely the :math:`k`-nearest training samples (as measured by the Euclidean +distance) that share the same label. Given these target neighbors, LMNN +learns a linear transformation of the data by optimizing a trade-off between +two goals. The first one is to make each (transformed) point closer to its +target neighbors than to any differently-labeled point by a large margin, +thereby enclosing the target neighbors in a sphere around the reference +sample. Data samples from different classes that violate this margin are +called "impostors". The second goal is to minimize the distances of each +sample to its target neighbors, which can be seen as a form of regularization. + +Classification +-------------- + +Combined with a nearest neighbors classifier (:class:`KNeighborsClassifier`), +this method is attractive for classification because it can naturally +handle multi-class problems without any increase in the model size, and only +a single parameter (``n_neighbors``) has to be selected by the user before +training. + +Large Margin Nearest Neighbor classification has been shown to work well in +practice for data sets of varying size and difficulty. In contrast to +related methods such as Linear Discriminant Analysis, LMNN does not make any +assumptions about the class distributions. The nearest neighbor classification +can naturally produce highly irregular decision boundaries. 
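+
+To make the notions of margin and impostors introduced above concrete, the
+following schematic sketch evaluates the pull and push terms for a single
+reference sample under a candidate linear transformation ``L``. It is for
+illustration only: the toy vectors and variable names are made up here and
+this is not the estimator's actual implementation::
+
+    import numpy as np
+
+    L = np.array([[1.0, 0.0], [0.0, 0.5]])   # candidate transformation
+    x_ref = np.array([0.0, 0.0])             # reference sample
+    x_tn = np.array([1.0, 0.0])              # target neighbor (same label)
+    x_other = np.array([1.0, 1.0])           # differently-labeled sample
+
+    # squared distances in the transformed space
+    d_tn = np.sum((L @ (x_ref - x_tn)) ** 2)
+    d_other = np.sum((L @ (x_ref - x_other)) ** 2)
+
+    pull_loss = d_tn                             # pulls the target neighbor closer
+    push_loss = max(0.0, 1.0 + d_tn - d_other)   # positive exactly when x_other
+                                                 # violates the margin (an impostor)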
+
+To use this model for classification, one needs to combine a
+:class:`LargeMarginNearestNeighbor` instance that learns the optimal
+transformation with a :class:`KNeighborsClassifier` instance that performs the
+classification in the embedded space. Here is an example using the two
+classes:
+
+    >>> from sklearn.neighbors import LargeMarginNearestNeighbor
+    >>> from sklearn.neighbors import KNeighborsClassifier
+    >>> from sklearn.datasets import load_iris
+    >>> from sklearn.model_selection import train_test_split
+    >>> X, y = load_iris(return_X_y=True)
+    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
+    ... stratify=y, test_size=0.7, random_state=42)
+    >>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
+    >>> lmnn.fit(X_train, y_train)
+    LargeMarginNearestNeighbor(...)
+    >>> # Apply the learned transformation when using KNeighborsClassifier
+    >>> knn = KNeighborsClassifier(n_neighbors=3)
+    >>> knn.fit(lmnn.transform(X_train), y_train)
+    KNeighborsClassifier(...)
+    >>> print(knn.score(lmnn.transform(X_test), y_test))
+    0.971428...
+
+Alternatively, one can create a :class:`sklearn.pipeline.Pipeline` instance
+that automatically applies the transformation when fitting or predicting:
+
+    >>> from sklearn.pipeline import Pipeline
+    >>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
+    >>> knn = KNeighborsClassifier(n_neighbors=3)
+    >>> lmnn_pipe = Pipeline([('lmnn', lmnn), ('knn', knn)])
+    >>> lmnn_pipe.fit(X_train, y_train)
+    Pipeline(...)
+    >>> print(lmnn_pipe.score(X_test, y_test))
+    0.971428...
+
+.. |lmnn_classification_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_001.png
+   :target: ../auto_examples/neighbors/plot_lmnn_classification.html
+   :scale: 50
+
+.. |lmnn_classification_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_002.png
+   :target: ../auto_examples/neighbors/plot_lmnn_classification.html
+   :scale: 50
+
+.. centered:: |lmnn_classification_1| |lmnn_classification_2|
+
+
+The plot shows decision boundaries for nearest neighbor classification and
+large margin nearest neighbor classification.
+
+.. _lmnn_dim_reduction:
+
+Dimensionality reduction
+------------------------
+
+:class:`LargeMarginNearestNeighbor` can be used to perform supervised
+dimensionality reduction. The input data are mapped to a linear subspace
+consisting of the directions which minimize the LMNN objective. Unlike
+unsupervised methods which aim to maximize the uncorrelatedness (PCA) or even
+independence (ICA) of the components, LMNN aims to find components that
+maximize the nearest neighbors classification accuracy of the transformed
+inputs. The desired output dimensionality can be set using the parameter
+``n_components``. For instance, the following shows a comparison of
+dimensionality reduction with Principal Component Analysis
+(:class:`sklearn.decomposition.PCA`), Linear Discriminant Analysis
+(:class:`sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) and Large
+Margin Nearest Neighbor (:class:`LargeMarginNearestNeighbor`) on the Olivetti
+Faces dataset, a dataset with size :math:`n_{samples} = 400` and
+:math:`n_{features} = 64 \times 64 = 4096`. The data set is split into a
+training and a test set of equal size. For evaluation, the 3-nearest neighbors
+classification accuracy is computed on the 2-dimensional embedding found by
+each method. Each data sample belongs to one of 40 classes.
+
+.. 
|lmnn_dim_reduction_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_001.png + :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html + :width: 32% + +.. |lmnn_dim_reduction_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_002.png + :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html + :width: 32% + +.. |lmnn_dim_reduction_3| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_003.png + :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html + :width: 32% + +.. centered:: |lmnn_dim_reduction_1| |lmnn_dim_reduction_2| |lmnn_dim_reduction_3| + + +Mathematical formulation +------------------------ + +LMNN learns a linear transformation matrix :math:`L` of +size ``(n_components, n_features)``. The objective function consists of +two competing terms, the pull loss that pulls target neighbors closer to +their reference sample and the push loss that pushes impostors away: + +.. math:: + \varepsilon_{\text{pull}} (L) = \sum_{i, j \rightsquigarrow i} ||L(x_i - x_j)||^2, +.. math:: + \varepsilon_{\text{push}} (L) = \sum_{i, j \rightsquigarrow i} + \sum_{l} (1 - y_{il}) [1 + || L(x_i - x_j)||^2 - || L + (x_i - x_l)||^2]_+, + +where :math:`y_{il} = 1` if :math:`y_i = y_l` and :math:`0` otherwise, +:math:`[x]_+ = \max(0, x)` is the hinge loss, and :math:`j \rightsquigarrow i` +means that the :math:`j^{th}` sample is a target neighbor of the +:math:`i^{th}` sample. + +LMNN solves the following (nonconvex) minimization problem: + +.. math:: + \min_L \varepsilon(L) = (1 - \mu) \varepsilon_{\text{pull}} (L) + + \mu \varepsilon_{\text{push}} (L) \text{, } \quad \mu \in [0,1]. + +The parameter :math:`\mu` (``push_loss_weight``) calibrates the trade-off +between penalizing large distances to target neighbors and penalizing margin +violations by impostors. In practice, the two terms are usually weighted +equally (:math:`\mu = 0.5`). + + +Mahalanobis distance +^^^^^^^^^^^^^^^^^^^^ + +LMNN can be seen as learning a (squared) Mahalanobis distance metric: + +.. math:: + || L(x_i - x_j)||^2 = (x_i - x_j)^TM(x_i - x_j), + +where :math:`M = L^T L` is a symmetric positive semi-definite matrix of size +``(n_features, n_features)``. The objective function of LMNN can be +rewritten and solved with respect to :math:`M` directly. This results in a +convex but constrained problem (since :math:`M` must be symmetric positive +semi-definite). See the journal paper in the References for more details. + + +Implementation +-------------- + +This implementation follows closely the MATLAB implementation found at +https://bitbucket.org/mlcircus/lmnn which solves the unconstrained problem. +It finds a linear transformation :math:`L` by optimization with L-BFGS instead +of solving the constrained problem that finds the globally optimal distance +metric. Different from the paper, the problem solved by this implementation is +with the *squared* hinge loss (to make the problem differentiable). + +See the examples below and the doc string of :meth:`LargeMarginNearestNeighbor.fit` +for further information. + +.. topic:: Examples: + + * :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_classification.py` + * :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_dim_reduction.py` + + +.. topic:: References: + + * `"Distance Metric Learning for Large Margin Nearest Neighbor Classification" + `_, + Weinberger, Kilian Q., and Lawrence K. Saul, Journal of Machine Learning + Research, Vol. 10, Feb. 2009, pp. 207-244. 
+
+  * `Wikipedia entry on Large Margin Nearest Neighbor
+    `_
diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 6ece2f16b6e93..986edfe03b452 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -348,13 +348,6 @@ Changelog
    :pr:`123456` by :user:`Joe Bloggs `.
    where 123456 is the *pull request* number, not the issue number.
 
-- |API| The option for using the squared error via ``loss`` and
-  ``criterion`` parameters was made more consistent. The preferred way is by
-  setting the value to `"squared_error"`. Old option names are still valid,
-  produce the same models, but are deprecated and will be removed in version
-  1.2.
-  :pr:`19310` by :user:`Christian Lorentzen `.
-
 - For :class:`ensemble.ExtraTreesRegressor`, `criterion="mse"` is deprecated,
   use `"squared_error"` instead which is now the default.
diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst
index ba51e28229462..a5cefe0810cd7 100644
--- a/doc/whats_new/v1.2.rst
+++ b/doc/whats_new/v1.2.rst
@@ -254,6 +254,11 @@ Changelog
   instead of failing with a low-level error message at predict-time.
   :pr:`23874` by :user:`Juan Gomez <2357juan>`.
 
+- |MajorFeature| A metric learning algorithm:
+  :class:`neighbors.LargeMarginNearestNeighbor`, which implements the
+  Large Margin Nearest Neighbor algorithm described in Weinberger et al.
+  (2006). :pr:`8602` by :user:`John Chiotellis `.
+
 :mod:`sklearn.svm`
 ..................
diff --git a/examples/neighbors/plot_lmnn_classification.py b/examples/neighbors/plot_lmnn_classification.py
new file mode 100644
index 0000000000000..c7ded9d6f55d3
--- /dev/null
+++ b/examples/neighbors/plot_lmnn_classification.py
@@ -0,0 +1,85 @@
+"""
+==========================================================================
+Comparing Nearest Neighbors with and without Large Margin Nearest Neighbor
+==========================================================================
+
+This example compares nearest neighbors classification with and without
+Large Margin Nearest Neighbor.
+
+It will plot the decision boundaries for each class determined by a simple
+Nearest Neighbors classifier against the decision boundaries determined by a
+Large Margin Nearest Neighbor classifier. The latter aims to find a distance
+metric that maximizes the nearest neighbor classification accuracy on a given
+training set.
+"""
+
+# Author: John Chiotellis
+# License: BSD 3 clause
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.colors import ListedColormap
+from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.neighbors import KNeighborsClassifier, LargeMarginNearestNeighbor
+from sklearn.pipeline import Pipeline
+
+
+print(__doc__)
+
+n_neighbors = 3
+
+# import some data to play with
+iris = datasets.load_iris()
+
+# we only take the first two features. 
We could avoid this ugly +# slicing by using a two-dim dataset +X = iris.data[:, :2] +y = iris.target + +X_train, X_test, y_train, y_test = \ + train_test_split(X, y, stratify=y, test_size=0.7, random_state=42) + +h = .01 # step size in the mesh + +# Create color maps +cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) +cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) + +names = ['K-Nearest Neighbors', 'Large Margin Nearest Neighbor'] + +classifiers = [KNeighborsClassifier(n_neighbors=n_neighbors), + Pipeline([('lmnn', LargeMarginNearestNeighbor( + n_neighbors=n_neighbors, random_state=42)), + ('knn', KNeighborsClassifier(n_neighbors)) + ]) + ] + +x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 +y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 +xx, yy = np.meshgrid(np.arange(x_min, x_max, h), + np.arange(y_min, y_max, h)) + +for name, clf in zip(names, classifiers): + + clf.fit(X_train, y_train) + score = clf.score(X_test, y_test) + + # Plot the decision boundary. For that, we will assign a color to each + # point in the mesh [x_min, x_max]x[y_min, y_max]. + Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) + + # Put the result into a color plot + Z = Z.reshape(xx.shape) + plt.figure() + plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=.8, shading='auto') + + # Plot also the training and testing points + plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20) + plt.xlim(xx.min(), xx.max()) + plt.ylim(yy.min(), yy.max()) + plt.title("{} (k = {})".format(name, n_neighbors)) + plt.text(0.9, 0.1, '{:.2f}'.format(score), size=15, + ha='center', va='center', transform=plt.gca().transAxes) + +plt.show() diff --git a/examples/neighbors/plot_lmnn_dim_reduction.py b/examples/neighbors/plot_lmnn_dim_reduction.py new file mode 100644 index 0000000000000..54da129f91d18 --- /dev/null +++ b/examples/neighbors/plot_lmnn_dim_reduction.py @@ -0,0 +1,102 @@ +""" +=========================================================== +Dimensionality Reduction with Large Margin Nearest Neighbor +=========================================================== + +This example compares different (linear) dimensionality reduction methods +applied on the Olivetti Faces data set. The data set contains ten different +images of each of 40 distinct persons. For some subjects, the images were +taken at different times, varying the lighting, facial expressions (open / +closed eyes, smiling / not smiling) and facial details (glasses / no glasses). +Each image of dimensions 64x64 is reduced to a two-dimensional data point. + +Principal Component Analysis (PCA) applied to this data identifies the +combination of attributes (principal components, or directions in the +feature space) that account for the most variance in the data. Here we +plot the different samples on the 2 first principal components. + +Linear Discriminant Analysis (LDA) tries to identify attributes that +account for the most variance *between classes*. In particular, +LDA, in contrast to PCA, is a supervised method, using known class labels. + +Large Margin Nearest Neighbor (LMNN) tries to find a feature space such that +the nearest neighbors classification accuracy is maximized. Like LDA, it is a +supervised method. + +In this example, nearest neighbors classification is used for supervised +prediction based on the embeddings learned from each of the three +decomposition approaches, PCA, LDA and LMNN. 
+ +One can see that LMNN enforces a clustering of the data that is visually +meaningful even after the large dimensionality reduction. +""" + + +# Author: John Chiotellis +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.neighbors import LargeMarginNearestNeighbor, KNeighborsClassifier + + +print(__doc__) + +n_neighbors = 3 +random_state = 0 + +# Load Olivetti Faces dataset +faces = datasets.fetch_olivetti_faces() +X, y = faces.data, faces.target + +# Split into train/test +X_train, X_test, y_train, y_test = \ + train_test_split(X, y, test_size=0.5, stratify=y, + random_state=random_state) + +dim = len(X[0]) +n_classes = len(np.unique(y)) + +# Reduce dimension to 2 with PCA +pca = PCA(n_components=2, random_state=random_state) + +# Reduce dimension to 2 with LinearDiscriminantAnalysis +lda = LinearDiscriminantAnalysis(n_components=2) + +# Reduce dimension to 2 with LargeMarginNearestNeighbor +lmnn = LargeMarginNearestNeighbor(n_neighbors=n_neighbors, n_components=2, + tol=0.1, verbose=1, + random_state=random_state) + +# Use a nearest neighbor classifier to evaluate the methods +knn = KNeighborsClassifier(n_neighbors=n_neighbors) + +# Make a list of the methods to be compared +dim_reduction_methods = [('PCA', pca), ('LDA', lda), ('LMNN', lmnn)] + +for i, (name, model) in enumerate(dim_reduction_methods): + plt.figure() + + # Fit the method's model + model.fit(X_train, y_train) + + # Fit a nearest neighbor classifier on the embedded training set + knn.fit(model.transform(X_train), y_train) + + # Compute the nearest neighbor accuracy on the embedded test set + acc_knn = knn.score(model.transform(X_test), y_test) + + # Embed the data set in 2 dimensions using the fitted model + X_embedded = model.transform(X) + + # Plot the embedding and show the evaluation score + plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y) + plt.title("{}, KNN (k={})".format(name, n_neighbors)) + plt.text(0.9, 0.1, '{:.2f}'.format(acc_knn), size=15, + ha='center', va='center', transform=plt.gca().transAxes) + +plt.show() diff --git a/examples/neighbors/plot_lmnn_illustration.py b/examples/neighbors/plot_lmnn_illustration.py new file mode 100644 index 0000000000000..bb4c29c75b211 --- /dev/null +++ b/examples/neighbors/plot_lmnn_illustration.py @@ -0,0 +1,153 @@ +""" +========================================== +Large Margin Nearest Neighbor Illustration +========================================== + +This example illustrates the goal of learning a distance metric that maximizes +the nearest neighbors classification accuracy. The example is solely for +illustration purposes. Please refer to the :ref:`User Guide ` for +more information. 
+""" + +# Author: John Chiotellis +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.datasets import make_classification +from sklearn.neighbors import LargeMarginNearestNeighbor +from sklearn.neighbors._lmnn import _select_target_neighbors + + +print(__doc__) + +n_neighbors = 1 +random_state = 0 + +# Create a tiny data set of 9 samples from 3 classes +X, y = make_classification(n_samples=9, n_features=2, n_informative=2, + n_redundant=0, n_classes=3, n_clusters_per_class=1, + class_sep=1.0, random_state=random_state) + +# Spread out the data so that a margin of 1 is visible +X *= 6 + +# Find the target neighbors +target_neighbors = _select_target_neighbors(X, y, n_neighbors) + +# Plot the points in the original space +plt.figure() +ax = plt.gca() + +# Draw the graph nodes +ax.scatter(X[:, 0], X[:, 1], s=300, c=y, alpha=0.4) +for i in range(X.shape[0]): + ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center') + +# Annotate the reference sample +ref_pos = X[3] +ref_text_pos = ref_pos - np.array([1, 1.5]) +ax.annotate('reference', xy=ref_pos, xytext=ref_text_pos, color='k', + style='italic') + +# Annotate the target neighbor relationship +tn_pos = X[target_neighbors[3]][0] +tn_text_pos = tn_pos + np.array([-4.5, 1]) +ax.annotate('target neighbor', xy=tn_pos, xytext=tn_text_pos, color='k', + style='italic') + +# Draw a circle with the radius touching the target neighbor +tn_edge = tn_pos - ref_pos +inner_radius = np.sqrt(np.dot(tn_edge, tn_edge)) +tn_circle = plt.Circle(ref_pos, inner_radius, color='b', linestyle='dashed', + fill=False, alpha=0.6) +ax.add_artist(tn_circle) + +# Draw an outer circle indicating the margin +margin = 1. +margin_circle = plt.Circle(ref_pos, inner_radius + margin, color='b', + linestyle='dashed', fill=False, alpha=0.6) +ax.add_artist(margin_circle) + +# Annotate the margin +margin_color = 'orange' + + +def fill_between_circles(ax, center, radii, color=margin_color): + n = 50 + theta = np.linspace(0, 2 * np.pi, n, endpoint=True) + xs = np.outer(radii, np.cos(theta)) + center[0] + ys = np.outer(radii, np.sin(theta)) + center[1] + + # in order to have a closed area, the circles + # should be traversed in opposite directions + xs[1, :] = xs[1, ::-1] + ys[1, :] = ys[1, ::-1] + + ax.fill(np.ravel(xs), np.ravel(ys), facecolor=color, alpha=0.1) + + +theta_ref = -3 * np.pi / 4 +vec_ref = np.array([np.cos(theta_ref), np.sin(theta_ref)]) +p_inner = ref_pos + vec_ref * inner_radius +p_outer = p_inner + vec_ref * margin +margin_text_pos = np.array([-13, -1]) +middle = (p_inner + p_outer) / 2 +ax.annotate('margin', xy=middle, xytext=margin_text_pos, style='italic', + arrowprops=dict(facecolor='black', arrowstyle='->', + connectionstyle="arc3,rad=-0.3")) +fill_between_circles(ax, X[3], [inner_radius, inner_radius + margin]) + +# Annotate the impostors (1, 4, 5, 7) +imp_centroid = (X[1] + X[4] + X[5] + X[7]) / 4 +imp_arrow_dict = dict(facecolor='black', arrowstyle='->', + connectionstyle="arc3,rad=0.3", shrinkB=9) +ax.annotate('', xy=X[1], xytext=imp_centroid, arrowprops=imp_arrow_dict) +ax.annotate('', xy=X[4], xytext=imp_centroid, arrowprops=imp_arrow_dict) +ax.annotate('', xy=X[5], xytext=imp_centroid, arrowprops=imp_arrow_dict) +ax.annotate('', xy=X[7], xytext=imp_centroid, arrowprops=imp_arrow_dict) +ax.text(imp_centroid[0] - 1, imp_centroid[1] + 1, 'impostors', color='k', + style='italic') + +# Make axes equal so that boundaries are displayed correctly as circles +plt.axis('equal') +ax.set_title("Input space") + + +# Learn an 
embedding with LargeMarginNearestNeighbor +lmnn = LargeMarginNearestNeighbor(n_neighbors=n_neighbors, max_iter=30, + random_state=random_state) +lmnn = lmnn.fit(X, y) + +# Plot the points after transformation with LargeMarginNearestNeighbor +plt.figure() +ax2 = plt.gca() + +# Get the embedding and find the new nearest neighbors +X_embedded = lmnn.transform(X) + +ax2.scatter(X_embedded[:, 0], X_embedded[:, 1], s=300, c=y, alpha=0.4) +for i in range(len(X)): + ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), + va='center', ha='center') + +# Draw a circle with the radius touching the target neighbor +tn_edge = X_embedded[3] - X_embedded[target_neighbors[3]][0] +inner_radius = np.sqrt(np.dot(tn_edge, tn_edge)) +tn_circle = plt.Circle(X_embedded[3], inner_radius, color='b', + linestyle='dashed', fill=False, alpha=0.6) +ax2.add_artist(tn_circle) + +# Draw an outer circle indicating the margin +margin_circle = plt.Circle(X_embedded[3], inner_radius + margin, color='b', + linestyle='dashed', fill=False, alpha=0.6) +ax2.add_artist(margin_circle) + +# Fill the margin with color +fill_between_circles(ax2, X_embedded[3], [inner_radius, inner_radius + margin]) + +# Make axes equal so that boundaries are displayed correctly as circles +plt.axis('equal') +ax2.set_title("LMNN embedding") + +plt.show() diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py index 12824e9cb684e..aa80bfff34f3a 100644 --- a/sklearn/neighbors/__init__.py +++ b/sklearn/neighbors/__init__.py @@ -14,6 +14,7 @@ from ._nearest_centroid import NearestCentroid from ._kde import KernelDensity from ._lof import LocalOutlierFactor +from ._lmnn import LargeMarginNearestNeighbor from ._nca import NeighborhoodComponentsAnalysis from ._base import sort_graph_by_row_values from ._base import VALID_METRICS, VALID_METRICS_SPARSE @@ -34,6 +35,7 @@ "radius_neighbors_graph", "KernelDensity", "LocalOutlierFactor", + "LargeMarginNearestNeighbor", "NeighborhoodComponentsAnalysis", "sort_graph_by_row_values", "VALID_METRICS", diff --git a/sklearn/neighbors/_lmnn.py b/sklearn/neighbors/_lmnn.py new file mode 100644 index 0000000000000..e156db96e2e37 --- /dev/null +++ b/sklearn/neighbors/_lmnn.py @@ -0,0 +1,1185 @@ +""" +Large Margin Nearest Neighbor Classification +""" + +# Author: John Chiotellis +# License: BSD 3 clause + +import sys +import time +from numbers import Integral, Real +from warnings import warn + +import numpy as np +from scipy.optimize import minimize +from scipy.sparse import csr_matrix, csc_matrix, coo_matrix + +from ..base import BaseEstimator, TransformerMixin +from ..neighbors import NearestNeighbors +from ..decomposition import PCA +from ..exceptions import ConvergenceWarning +from ..utils import gen_batches, get_chunk_n_rows +from ..utils.extmath import ( + _euclidean_distances_without_checks, + row_norms, + safe_sparse_dot, +) +from ..utils.multiclass import check_classification_targets +from ..utils._param_validation import Interval, StrOptions +from ..utils.random import check_random_state +from ..utils.validation import check_is_fitted, check_array, check_X_y + + +class LargeMarginNearestNeighbor(BaseEstimator, TransformerMixin): + r"""Distance metric learning for large margin classification. + + Large Margin Nearest Neighbor (LMNN) is a machine learning algorithm for + metric learning. It learns a linear transformation in a supervised fashion + to improve the classification accuracy of the :math:`k`-nearest neighbors + rule in the transformed space. + + Read more in the :ref:`User Guide `. + + .. 
versionadded:: 1.2
+
+    Parameters
+    ----------
+    n_neighbors : int, default=3
+        Number of samples to use as target neighbors for each sample.
+
+    n_components : int, default=None
+        Preferred dimensionality of the transformed samples.
+        If `None` it is inferred from `init`.
+
+    init : {"pca", "identity"} or array-like of shape (n_features_a, n_features), \
+            default="pca"
+        Initialization of the linear transformation. The possible options are:
+
+        * "pca": `n_components` many principal components of the inputs passed
+          to :meth:`fit` will be used to initialize the transformation. If
+          `n_components` is `None` all components are kept and thus
+          `n_components = min(n_samples, n_features)`.
+        * "identity": if `n_components` is strictly smaller than the
+          dimensionality of the inputs passed to :meth:`fit`, the identity
+          matrix will be truncated to the first `n_components` rows.
+        * ndarray: `n_features` must match the dimensionality of the inputs
+          passed to :meth:`fit` and `n_features_a` must be less than or equal to
+          that. If `n_components` is not `None`, `n_features_a` must match it.
+
+    warm_start : bool, default=False
+        If `True` and :meth:`fit` has been called before, the solution of the
+        previous call to :meth:`fit` is used as the initial linear
+        transformation (`n_components` and `init` will be ignored).
+
+    max_impostors : int, default=500_000
+        Maximum number of impostors to consider per iteration. Impostors are
+        samples that are too close to a sample with a different label,
+        thereby violating their margin. In the worst case this will allow
+        `max_impostors * n_neighbors` constraints to be active.
+
+    neighbors_params : dict, default=None
+        Parameters to pass to a :class:`~sklearn.neighbors.NearestNeighbors`
+        instance - apart from `n_neighbors` - that will be used to select the
+        target neighbors.
+
+    push_loss_weight : float, default=0.5
+        A float in (0, 1], weighting the push loss. This is parameter
+        :math:`\mu` in [1]_. In practice, the objective function will be
+        normalized so that the push loss has weight 1 and hence the pull loss
+        has weight `(1 - mu) / mu`.
+
+    impostor_store : {"list", "sparse", "auto"}, default="auto"
+        * "list": Three lists will be used to store the indices of reference
+          samples, the indices of their impostors and the squared
+          distances between each (sample, impostor) pair.
+        * "sparse": A sparse indicator matrix will be used to store the
+          (sample, impostor) pairs. The squared distances to the impostors will
+          be computed twice (once to determine the impostors and once to be
+          stored), but this option tends to be faster than 'list' as the size
+          of the data set increases.
+        * "auto": Will attempt to decide the most appropriate choice of data
+          structure based on the values passed to :meth:`fit`.
+
+    max_iter : int, default=50
+        Maximum number of iterations in the optimization.
+
+    tol : float, default=1e-5
+        Convergence tolerance for the optimization.
+
+    callback : callable, default=None
+        If not None, this function is called after every iteration of the
+        optimizer, taking as arguments the current solution (flattened
+        transformation matrix) and the number of iterations. This might be
+        useful in case one wants to examine or store the transformation
+        found after each iteration.
+
+    store_opt_result : bool, default=False
+        If True, the :class:`scipy.optimize.OptimizeResult` object returned by
+        :meth:`minimize` of :mod:`scipy.optimize` will be stored as attribute
+        `opt_result_`.
+ + verbose : int, default=0 + If 0, no progress messages will be printed. + If 1, progress messages will be printed to stdout. + If > 1, progress messages will be printed and the `disp` parameter + of :func:`scipy.optimize.minimize` will be set to `verbose - 2`. + + random_state : int, RandomState instance or None, default=None + Pseudo-random number generator to the sub-sampling of the impostors if + they exceed `max_impostors` and to control the initialization of the + linear transformation if `init="pca"`. + Pass an int for reproducible output across multiple function calls. + See :term:`Glossary `. + + n_jobs : int, default=None + The number of jobs to use for the neighbors search. + ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. + ``-1`` means using all processors. See :term:`Glossary ` + for more details. Doesn't affect :meth:`fit` method. + + Attributes + ---------- + components_ : ndarray of shape (n_components, n_features) + The linear transformation learned during fitting. + + n_neighbors_ : int + The provided `n_neighbors` is decreased if it is greater than or + equal to minimum of the number of elements in each class. + + n_iter_ : int + Counts the number of iterations performed by the optimizer. + + opt_result_ : scipy.optimize.OptimizeResult + A dictionary of information representing the optimization result. + This is stored only if `store_opt_result=True`. It contains the + following attributes: + + x : ndarray + The solution of the optimization. + success : bool + Whether or not the optimizer exited successfully. + status : int + Termination status of the optimizer. + message : str + Description of the cause of the termination. + fun, jac : ndarray + Values of objective function and its Jacobian. + hess_inv : scipy.sparse.linalg.LinearOperator + the product of a vector with the approximate inverse of the + Hessian of the objective function.. + nfev : int + Number of evaluations of the objective function.. + nit : int + Number of iterations performed by the optimizer. + + Notes + ----- + + .. warning:: + + Exact floating-point reproducibility is generally not guaranteed + (unless special care is taken with library and compiler options). As + a consequence, the transformations computed in 2 identical runs of + LargeMarginNearestNeighbor can differ from each other. This can + happen even before the optimizer is called if initialization with + PCA is used (init='pca'). + + References + ---------- + .. [1] `Weinberger, Kilian Q., and Lawrence K. Saul. + "Distance Metric Learning for Large Margin Nearest Neighbor Classification." + Journal of Machine Learning Research, Vol. 10, Feb. 2009, pp. 207-244. + `_ + + .. [2] `Wikipedia entry on Large Margin Nearest Neighbor + `_ + + Examples + -------- + >>> from sklearn.neighbors import LargeMarginNearestNeighbor + >>> from sklearn.neighbors import KNeighborsClassifier + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import train_test_split + >>> X, y = load_iris(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... stratify=y, test_size=0.7, random_state=42) + >>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42) + >>> lmnn.fit(X_train, y_train) + LargeMarginNearestNeighbor(...) + >>> # Fit and evaluate a simple nearest neighbor classifier for comparison + >>> knn = KNeighborsClassifier(n_neighbors=3) + >>> knn.fit(X_train, y_train) + KNeighborsClassifier(...) + >>> print(knn.score(X_test, y_test)) + 0.93... 
+ >>> # Now fit on the data transformed by the learned transformation + >>> knn.fit(lmnn.transform(X_train), y_train) + KNeighborsClassifier(...) + >>> print(knn.score(lmnn.transform(X_test), y_test)) + 0.97... + """ + + _parameter_constraints = { + "n_neighbors": [Interval(Integral, left=1, right=None, closed="left")], + "n_components": [Interval(Integral, left=1, right=None, closed="left"), None], + "init": [StrOptions({"pca", "identity"}), "array-like"], + "warm_start": ["boolean"], + "max_impostors": [Integral], + "neighbors_params": [dict, None], + "push_loss_weight": [Interval(Real, left=0, right=1, closed="right")], + "impostor_store": [StrOptions({"list", "sparse", "auto"})], + "max_iter": [Interval(Integral, left=1, right=None, closed="left")], + "tol": [Interval(Real, left=0, right=None, closed="left")], + "callback": [callable, None], + "store_opt_result": ["boolean"], + "verbose": ["verbose"], + "random_state": ["random_state"], + "n_jobs": [Integral, None], + } + + def __init__( + self, + n_neighbors=3, + n_components=None, + *, + init="pca", + warm_start=False, + max_impostors=500_000, + neighbors_params=None, + push_loss_weight=0.5, + impostor_store="auto", + max_iter=50, + tol=1e-5, + callback=None, + store_opt_result=False, + verbose=0, + random_state=None, + n_jobs=None, + ): + + # Parameters + self.n_neighbors = n_neighbors + self.n_components = n_components + self.init = init + self.warm_start = warm_start + self.max_impostors = max_impostors + self.neighbors_params = neighbors_params + self.push_loss_weight = push_loss_weight + self.impostor_store = impostor_store + self.max_iter = max_iter + self.tol = tol + self.callback = callback + self.store_opt_result = store_opt_result + self.verbose = verbose + self.random_state = random_state + self.n_jobs = n_jobs + + def fit(self, X, y): + """Fit the model according to the given training data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The training samples. + + y : array-like of shape (n_samples,) + The corresponding training labels. + + Returns + ------- + self : object + returns a trained LargeMarginNearestNeighbor model. 
+ """ + + self._validate_params() + + # Check that the inputs are consistent with the parameters + X, y, classes, initial_transformation = self._validate_params_inner(X, y) + + # Initialize the random generator + self.random_state_ = check_random_state(self.random_state) + + # Measure the total training time + t_train = time.time() + + # Initialize the linear transformation + transformation = self._initialize(X, initial_transformation) + + # Find the target neighbors + target_neighbors = self._select_target_neighbors_wrapper(X, y, classes) + + # Compute the gradient part contributed by the target neighbors + grad_static = self._compute_grad_static(X, target_neighbors) + + # Compute the pull loss coefficient + pull_loss_weight = (1.0 - self.push_loss_weight) / self.push_loss_weight + grad_static *= pull_loss_weight + + # Decide how to store the impostors + if self.impostor_store == "sparse": + use_sparse = True + elif self.impostor_store == "list": + use_sparse = False + else: + # 'auto': Use a heuristic based on the data set size + use_sparse = X.shape[0] > 6500 + + # Create a dictionary of parameters to be passed to the optimizer + disp = self.verbose - 2 if self.verbose > 1 else -1 + optimizer_params = { + "method": "L-BFGS-B", + "fun": self._loss_grad_lbfgs, + "jac": True, + "args": (X, y, classes, target_neighbors, grad_static, use_sparse), + "x0": transformation, + "tol": self.tol, + "options": dict(maxiter=self.max_iter, disp=disp), + "callback": self._callback, + } + + # Call the optimizer + self.n_iter_ = 0 + opt_result = minimize(**optimizer_params) + + # Reshape the solution found by the optimizer + self.components_ = opt_result.x.reshape(-1, X.shape[1]) + + # Stop timer + t_train = time.time() - t_train + if self.verbose: + cls_name = self.__class__.__name__ + + # Warn the user if the algorithm did not converge + if not opt_result.success: + warn( + "[{}] LMNN did not converge: {}".format( + cls_name, opt_result.message + ), + ConvergenceWarning, + ) + + print("[{}] Training took {:8.2f}s.".format(cls_name, t_train)) + + # Optionally store information returned by the optimizer + if self.store_opt_result: + self.opt_result_ = opt_result + + return self + + def transform(self, X): + """Applies the learned transformation to the given data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Data samples. + + Returns + ------- + X_transformed : ndarray of shape (n_samples, n_components) + The data samples transformed. + + Raises + ------ + NotFittedError + If :meth:`fit` has not been called before. + + Notes + ----- + A simple dot product is necessary and sufficient to transform the + inputs into the learned subspace. Orthogonality of the components is + only enforced upon initialization if PCA is used (`init="pca"`). + """ + + check_is_fitted(self, ["components_"]) + X = check_array(X) + + return np.dot(X, self.components_.T) + + def _validate_params_inner(self, X, y): + """Validate parameters as soon as :meth:`fit` is called. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The training samples. + + y : array-like of shape (n_samples,) + The corresponding training labels. + + Returns + ------- + X : array of shape (n_samples, n_features) + The validated training samples. + + y : array of shape (n_samples,) + The validated training labels, encoded to be integers in the + range [0, n_classes). 
+ + classes_inverse_non_singleton : ndarray of shape (n_classes_non_singleton,) + The non-singleton classes, encoded as integers in [0, n_classes). + + init : str or ndarray of shape (n_features_a, n_features) + The validated initialization of the linear transformation. + + Raises + ------ + TypeError + If a parameter is not an instance of the desired type. + + ValueError + If a parameter's value violates its legal value range or if the + combination of two or more given parameters is incompatible. + """ + + # Validate the inputs + X, y = check_X_y(X, y, ensure_min_samples=2) + check_classification_targets(y) + + # Find the appearing classes and the class index of each of the samples + classes, y = np.unique(y, return_inverse=True) + classes_inverse = np.arange(len(classes)) + + # Ignore classes that have less than two samples (singleton classes) + class_sizes = np.bincount(y) + mask_singleton_class = class_sizes == 1 + singleton_classes = np.where(mask_singleton_class)[0] + if len(singleton_classes): + warn( + f"There are {len(singleton_classes)} singleton classes that will be " + "ignored during training. A copy of the inputs `X` and `y` will be " + "made." + ) + mask_singleton_sample = np.asarray([yi in singleton_classes for yi in y]) + X = X[~mask_singleton_sample].copy() + y = y[~mask_singleton_sample].copy() + + # Check that there are at least 2 non-singleton classes + n_classes_non_singleton = len(classes) - len(singleton_classes) + if n_classes_non_singleton < 2: + raise ValueError( + "LargeMarginNearestNeighbor needs at least 2 " + f"non-singleton classes, got {n_classes_non_singleton}." + ) + + classes_inverse_non_singleton = classes_inverse[~mask_singleton_class] + + # Check the preferred dimensionality of the transformed samples + if self.n_components is not None and self.n_components > X.shape[1]: + raise ValueError( + "The preferred output dimensionality " + f"`n_components` ({self.n_components}) cannot be greater " + f"than the given data dimensionality ({X.shape[1]})!" + ) + + # If warm_start is enabled, check that the inputs are consistent + if self.warm_start and hasattr(self, "components_"): + if self.components_.shape[1] != X.shape[1]: + raise ValueError( + f"The new inputs dimensionality ({X.shape[1]}) does not " + "match the input dimensionality of the " + f"previously learned transformation ({self.components_.shape[1]})." + ) + + # Check how the linear transformation should be initialized + init = self.init + if not isinstance(init, str): + init = check_array(init) + + # Assert that init.shape[1] = X.shape[1] + if init.shape[1] != X.shape[1]: + raise ValueError( + f"The input dimensionality ({init.shape[1]}) of the given " + "linear transformation `init` must match the " + f"dimensionality of the given inputs `X` ({X.shape[1]})." + ) + + # Assert that init.shape[0] <= init.shape[1] + if init.shape[0] > init.shape[1]: + raise ValueError( + f"The output dimensionality ({init.shape[0]}) of the given " + "linear transformation `init` cannot be " + f"greater than its input dimensionality ({init.shape[1]})." + ) + + if self.n_components is not None: + # Assert that self.n_components = init.shape[0] + if self.n_components != init.shape[0]: + raise ValueError( + "The preferred output dimensionality " + f"`n_components` ({self.n_components}) does not match " + "the output dimensionality of the given " + f"linear transformation `init` ({init.shape[0]})!" 
+                    )
+
+        # Check the preferred number of neighbors
+        min_non_singleton_size = class_sizes[~mask_singleton_class].min()
+        if self.n_neighbors >= min_non_singleton_size:
+            warn(
+                f"`n_neighbors` (={self.n_neighbors}) is not less than the "
+                "number of samples in the smallest non-singleton class "
+                f"(={min_non_singleton_size}). `n_neighbors_` will be set to "
+                f"{min_non_singleton_size - 1} for estimation."
+            )
+
+        self.n_neighbors_ = min(self.n_neighbors, min_non_singleton_size - 1)
+
+        neighbors_params = self.neighbors_params
+        if neighbors_params is not None:
+            neighbors_params.setdefault("n_jobs", self.n_jobs)
+            # Attempt to instantiate a NearestNeighbors instance here to
+            # raise any errors before actually fitting
+            NearestNeighbors(n_neighbors=self.n_neighbors_, **neighbors_params)
+
+        return X, y, classes_inverse_non_singleton, init
+
+    def _initialize(self, X, init):
+        """Initialize the transformation matrix.
+
+        Parameters
+        ----------
+        X : ndarray of shape (n_samples, n_features)
+            The training samples.
+
+        init : str or ndarray of shape (n_features_a, n_features)
+            The initialization of the linear transformation.
+
+        Returns
+        -------
+        transformation : ndarray of shape (n_components, n_features)
+            The initialized linear transformation.
+        """
+
+        transformation = init
+        if self.warm_start and hasattr(self, "components_"):
+            transformation = self.components_
+
+        elif isinstance(init, np.ndarray):
+            pass
+
+        elif init == "pca":
+            pca = PCA(n_components=self.n_components, random_state=self.random_state_)
+            t_pca = time.time()
+            if self.verbose:
+                print(
+                    "[{}] Finding principal components...".format(
+                        self.__class__.__name__
+                    )
+                )
+                sys.stdout.flush()
+
+            pca.fit(X)
+            if self.verbose:
+                t_pca = time.time() - t_pca
+                print(
+                    "[{}] Found principal components in {:5.2f}s.".format(
+                        self.__class__.__name__, t_pca
+                    )
+                )
+
+            transformation = pca.components_
+
+        elif init == "identity":
+            if self.n_components is None:
+                transformation = np.eye(X.shape[1])
+            else:
+                transformation = np.eye(self.n_components, X.shape[1])
+
+        return transformation
+
+    def _select_target_neighbors_wrapper(self, X, y, classes=None):
+        """Find the target neighbors of each of the data samples.
+
+        Parameters
+        ----------
+        X : ndarray of shape (n_samples, n_features)
+            The training samples.
+
+        y : ndarray of shape (n_samples,)
+            The corresponding training label indices.
+
+        classes : ndarray of shape (n_classes,), default=None
+            The non-singleton classes, encoded as integers in `[0, n_classes)`.
+            If None (default), they will be inferred from `y`.
+
+        Returns
+        -------
+        target_neighbors : ndarray of shape (n_samples, n_neighbors)
+            An array of neighbors indices for each of the samples.
+        """
+
+        t_start = time.time()
+        if self.verbose:
+            print(
+                "[{}] Finding the target neighbors...".format(self.__class__.__name__)
+            )
+            sys.stdout.flush()
+
+        neighbors_params = self.neighbors_params
+        if neighbors_params is None:
+            neighbors_params = {}
+
+        neighbors_params.setdefault("n_jobs", self.n_jobs)
+        target_neighbors = _select_target_neighbors(
+            X, y, self.n_neighbors_, classes=classes, **neighbors_params
+        )
+
+        if self.verbose:
+            print(
+                "[{}] Found the target neighbors in {:5.2f}s.".format(
+                    self.__class__.__name__, time.time() - t_start
+                )
+            )
+
+        return target_neighbors
+
+    def _compute_grad_static(self, X, target_neighbors):
+        """Compute the gradient contributed by the target neighbors.
+
+        Parameters
+        ----------
+        X : ndarray of shape (n_samples, n_features)
+            The training samples.
+ + target_neighbors : ndarray of shape (n_samples, n_neighbors) + The k nearest neighbors of each of the samples from the same class. + + Returns + ------- + grad_target_neighbors : ndarray of shape (n_features, n_features) + An array with the sum of all outer products of + `(sample, target_neighbor)` pairs. + """ + + start_time = time.time() + if self.verbose: + print( + "[{}] Computing static part of the gradient...".format( + self.__class__.__name__ + ) + ) + + n_samples, n_neighbors = target_neighbors.shape + row = np.repeat(range(n_samples), n_neighbors) + col = target_neighbors.ravel() + tn_graph = csr_matrix( + (np.ones(target_neighbors.size), (row, col)), shape=(n_samples, n_samples) + ) + grad_target_neighbors = _sum_weighted_outer_differences(X, tn_graph) + + if self.verbose: + duration = time.time() - start_time + print( + "[{}] Computed static part of the gradient in {:5.2f}s.".format( + self.__class__.__name__, duration + ) + ) + + return grad_target_neighbors + + def _callback(self, transformation): + """Called after each iteration of the optimizer. + + Parameters + ---------- + transformation : ndarray of shape(n_components * n_features,) + The solution computed by the optimizer in this iteration. + """ + if self.callback is not None: + self.callback(transformation, self.n_iter_) + + self.n_iter_ += 1 + + def _loss_grad_lbfgs( + self, transformation, X, y, classes, target_neighbors, grad_static, use_sparse + ): + """Compute the loss and the loss gradient w.r.t. ``transformation``. + + Parameters + ---------- + transformation : ndarray of shape (n_components * n_features,) + The current (flattened) linear transformation. + + X : ndarray of shape (n_samples, n_features) + The training samples. + + y : ndarray of shape (n_samples,) + The corresponding training labels. + + classes : ndarray of shape (n_classes,) + The non-singleton classes, encoded as integers in `[0, n_classes)`. + + target_neighbors : ndarray of shape (n_samples, n_neighbors) + The target neighbors of each of the training samples. + + grad_static : ndarray of shape (n_features, n_features) + The (weighted) gradient component caused by target neighbors, + that stays fixed throughout the algorithm. + + use_sparse : bool + Whether to use a sparse matrix or lists to store the impostors. + + Returns + ------- + loss : float + The loss based on the given transformation. + + grad : ndarray of shape (n_components * n_features,) + The new (flattened) gradient of the loss. 
+ """ + + n_samples, n_features = X.shape + transformation = transformation.reshape(-1, n_features) + self.components_ = transformation + + if self.n_iter_ == 0: + self.n_iter_ += 1 + if self.verbose: + header_fields = [ + "Iteration", + "Objective Value", + "#Active Triplets", + "Time(s)", + ] + header_fmt = "{:>10} {:>20} {:>20} {:>10}" + header = header_fmt.format(*header_fields) + cls_name = self.__class__.__name__ + print("[{}]".format(cls_name)) + print("[{}] {}".format(cls_name, header)) + print("[{}] {}".format(cls_name, "-" * len(header))) + + t_funcall = time.time() + + # transform without checks + X_transformed = np.dot(X, transformation.T) + + # Compute squared distances to the target neighbors + n_neighbors = target_neighbors.shape[1] + dist_tn = np.zeros((n_samples, n_neighbors)) + for k in range(n_neighbors): + ind_tn = target_neighbors[:, k] + dist_tn[:, k] = row_norms( + X_transformed - X_transformed[ind_tn], squared=True + ) + + # Add the margin to all squared distances to target neighbors + dist_tn += 1 + + # Find the impostors and compute squared distances to them + impostors_graph = self._find_impostors( + X_transformed, y, classes, dist_tn[:, -1], use_sparse + ) + + # Compute the push loss and its gradient + loss, grad_new, n_active_triplets = _compute_push_loss( + X, target_neighbors, dist_tn, impostors_graph + ) + + # Compute the total gradient + grad = np.dot(transformation, grad_static + grad_new) + grad *= 2 + + # Add the (weighted) pull loss to the total loss + metric = np.dot(transformation.T, transformation) + loss += np.dot(grad_static.ravel(), metric.ravel()) + + if self.verbose: + t_funcall = time.time() - t_funcall + values_fmt = "[{}] {:>10} {:>20.6e} {:>20,} {:>10.2f}" + print( + values_fmt.format( + self.__class__.__name__, + self.n_iter_, + loss, + n_active_triplets, + t_funcall, + ) + ) + sys.stdout.flush() + + return loss, grad.ravel() + + def _find_impostors(self, X_transformed, y, classes, margin_radii, use_sparse=True): + """Compute the (sample, impostor) pairs exactly. + + Parameters + ---------- + X_transformed : ndarray of shape (n_samples, n_components) + An array of transformed samples. + + y : ndarray of shape (n_samples,) + The corresponding (possibly encoded) class labels. + + classes : ndarray of shape (n_classes,) + The non-singleton classes, encoded as integers in `[0, n_classes)`. + + margin_radii : ndarray of shape (n_samples,) + Squared distances of samples to their farthest target neighbors + plus margin. + + use_sparse : bool, default=True + Whether to use a sparse matrix or lists to store the + `(sample, impostor)` pairs. + + Returns + ------- + impostors_graph : coo_matrix, shape (n_samples, n_samples) + If at least one of two violations is active (sample i is an + impostor to j or sample j is an impostor to i), then one of the + two entries (i, j) or (j, i) will hold the squared distance + between the two samples. Otherwise both entries will be zero. 
+ + """ + n_samples = X_transformed.shape[0] + + if use_sparse: + # Initialize a sparse (indicator) matrix for impostors storage + impostors_sp = csr_matrix((n_samples, n_samples), dtype=np.int8) + for class_id in classes[:-1]: + ind_in = np.where(y == class_id)[0] + ind_out = np.where(y > class_id)[0] + + # Split impostors computation into chunks that fit in memory + imp_ind = _find_impostors_chunked( + X_transformed[ind_in], + X_transformed[ind_out], + margin_radii[ind_in], + margin_radii[ind_out], + ) + + if len(imp_ind): + # Subsample impostors if they are too many + if len(imp_ind) > self.max_impostors: + imp_ind = self.random_state_.choice( + imp_ind, self.max_impostors, replace=False + ) + + dims = (len(ind_out), len(ind_in)) + ii, jj = np.unravel_index(imp_ind, dims) + # Convert indices to refer to the original data matrix + imp_row = ind_out[ii] + imp_col = ind_in[jj] + new_imp = csr_matrix( + (np.ones(len(imp_row), dtype=np.int8), (imp_row, imp_col)), + dtype=np.int8, + shape=(n_samples, n_samples), + ) + impostors_sp = impostors_sp + new_imp + + impostors_sp = impostors_sp.tocoo(copy=False) + imp_row = impostors_sp.row + imp_col = impostors_sp.col + + # Make sure we do not exceed max_impostors + n_impostors = len(imp_row) + if n_impostors > self.max_impostors: + ind_sampled = self.random_state_.choice( + n_impostors, self.max_impostors, replace=False + ) + imp_row = imp_row[ind_sampled] + imp_col = imp_col[ind_sampled] + + imp_dist = _paired_distances_chunked(X_transformed, imp_row, imp_col) + else: + # Initialize lists for impostors storage + imp_row, imp_col, imp_dist = [], [], [] + for class_id in classes[:-1]: + ind_in = np.where(y == class_id)[0] + ind_out = np.where(y > class_id)[0] + + # Split impostors computation into chunks that fit in memory + imp_ind, dist_batch = _find_impostors_chunked( + X_transformed[ind_in], + X_transformed[ind_out], + margin_radii[ind_in], + margin_radii[ind_out], + return_distance=True, + ) + + if len(imp_ind): + # Subsample impostors if they are too many + if len(imp_ind) > self.max_impostors: + ind_sampled = self.random_state_.choice( + len(imp_ind), self.max_impostors, replace=False + ) + imp_ind = imp_ind[ind_sampled] + dist_batch = dist_batch[ind_sampled] + + dims = (len(ind_out), len(ind_in)) + ii, jj = np.unravel_index(imp_ind, dims) + # Convert indices to refer to the original data matrix + imp_row.extend(ind_out[ii]) + imp_col.extend(ind_in[jj]) + imp_dist.extend(dist_batch) + + imp_row = np.asarray(imp_row, dtype=np.intp) + imp_col = np.asarray(imp_col, dtype=np.intp) + imp_dist = np.asarray(imp_dist) + + # Make sure we do not exceed max_impostors + n_impostors = len(imp_row) + if n_impostors > self.max_impostors: + ind_sampled = self.random_state_.choice( + n_impostors, self.max_impostors, replace=False + ) + imp_row = imp_row[ind_sampled] + imp_col = imp_col[ind_sampled] + imp_dist = imp_dist[ind_sampled] + + impostors_graph = coo_matrix( + (imp_dist, (imp_row, imp_col)), shape=(n_samples, n_samples) + ) + + return impostors_graph + + +######################## +# Some core functions # +####################### + + +def _select_target_neighbors(X, y, n_neighbors, classes=None, **nn_kwargs): + """Find the target neighbors of each of the training samples. + + Parameters + ---------- + X : ndarray of shape (n_samples, n_features) + The training samples. + + y : ndarray of shape (n_samples,) + The corresponding (encoded) training labels. + + n_neighbors : int + The number of target neighbors to select for each sample in `X`. 
+ + classes : ndarray of shape (n_classes,), default=None + The non-singleton classes, encoded as integers in `[0, n_classes)`. + If None (default), they will be inferred from `y`. + + **nn_kwargs : dict + Parameters to be passed to a :class:`~sklearn.neighbors.NearestNeighbors` + instance except from `n_neighbors`. + + Returns + ------- + target_neighbors : ndarray of shape (n_samples, n_neighbors) + The indices of the target neighbors of each training sample. + """ + target_neighbors = np.zeros((X.shape[0], n_neighbors), dtype=np.intp) + + nn = NearestNeighbors(n_neighbors=n_neighbors, **nn_kwargs) + + if classes is None: + classes = np.unique(y) + + for class_id in classes: + ind_class = np.where(y == class_id)[0] + nn.fit(X[ind_class]) + neigh_ind = nn.kneighbors(return_distance=False) + target_neighbors[ind_class] = ind_class[neigh_ind] + + return target_neighbors + + +def _find_impostors_chunked(X_in, X_out, radii_in, radii_out, return_distance=False): + """Find (sample, impostor) pairs in chunks to avoid large memory usage. + + Parameters + ---------- + X_in : ndarray of shape (n_samples_a, n_components) + Transformed data samples that belong to class A. + + X_out : ndarray of shape (n_samples_b, n_components) + Transformed data samples that belong to classes different from A. + + radii_in : ndarray of shape (n_samples_a,) + Squared distances of the samples in `X_in` to their margins. + + radii_out : ndarray of shape (n_samples_b,) + Squared distances of the samples in `X_out` to their margins. + + return_distance : bool, optional (default=False) + Whether to return the squared distances to the impostors. + + Returns + ------- + imp_indices : ndarray of shape (n_impostors,) + Unraveled indices referring to a matrix of shape + (n_samples_b, n_samples_a). Index pair (i, j) is returned (unraveled) + if either sample i is an impostor to sample j or sample j is an + impostor to sample i. + + imp_distances : ndarray of shape (n_impostors,), optional + imp_distances[i] is the squared distance between samples imp_row[i] and + imp_col[i], where + imp_row, imp_col = np.unravel_index(imp_indices, (n_samples_b, + n_samples_a)) + """ + n_samples_b = X_out.shape[0] + row_bytes = X_in.shape[0] * X_in.itemsize + chunk_size = get_chunk_n_rows(row_bytes, max_n_rows=n_samples_b) + imp_indices, imp_distances = [], [] + + # X_in squared norm stays constant, so pre-compute it to get a speed-up + X_in_norm_squared = row_norms(X_in, squared=True)[np.newaxis, :] + for sl in gen_batches(n_samples_b, chunk_size): + # The function `sklearn.metrics.pairwise.euclidean_distances` would + # add an extra ~8% time of computation due to input validation on + # every chunk and another ~8% due to clipping of negative values. + distances_chunk = _euclidean_distances_without_checks( + X_out[sl], X_in, squared=True, Y_norm_squared=X_in_norm_squared, clip=False + ) + + ind_out = np.where((distances_chunk < radii_in[None, :]).ravel())[0] + ind_in = np.where((distances_chunk < radii_out[sl, None]).ravel())[0] + ind = np.unique(np.concatenate((ind_out, ind_in))) + + if len(ind): + ind_plus_offset = ind + sl.start * X_in.shape[0] + imp_indices.extend(ind_plus_offset) + + if return_distance: + # We only need to do clipping if we return the distances. 
+ distances_chunk = distances_chunk.ravel()[ind] + # Clip only the indexed (unique) distances + np.maximum(distances_chunk, 0, out=distances_chunk) + imp_distances.extend(distances_chunk) + + imp_indices = np.asarray(imp_indices) + + if return_distance: + return imp_indices, np.asarray(imp_distances) + else: + return imp_indices + + +def _compute_push_loss(X, target_neighbors, inflated_dist_tn, impostors_graph): + """Compute the loss. + + Parameters + ---------- + X : ndarray of shape (n_samples, n_features) + The training input samples. + + target_neighbors : ndarray of shape (n_samples, n_neighbors) + Indices of target neighbors of each sample in X. + + inflated_dist_tn : ndarray of shape (n_samples, n_neighbors) + Squared distances of each sample to their target neighbors plus margin. + + impostors_graph : coo_matrix, shape (n_samples, n_samples) + If at least one of two violations is active (sample i is an impostor + to j or sample j is an impostor to i), then one of the two entries + (i, j) or (j, i) will hold the squared distance between the two + samples. Otherwise both entries will be zero. + + Returns + ------- + loss : float + The push loss caused by the given target neighbors and impostors. + + grad : ndarray of shape (n_features, n_features) + The gradient of the push loss. + + n_active_triplets : int + The number of active triplet constraints. + """ + + n_samples, n_neighbors = inflated_dist_tn.shape + imp_row = impostors_graph.row + imp_col = impostors_graph.col + dist_impostors = impostors_graph.data + + loss = 0 + shape = (n_samples, n_samples) + A0 = csr_matrix(shape) + sample_range = range(n_samples) + n_active_triplets = 0 + for k in reversed(range(n_neighbors)): + # Consider margin violations to the samples in imp_row + loss1 = np.maximum(inflated_dist_tn[imp_row, k] - dist_impostors, 0) + ac = np.where(loss1 > 0)[0] + n_active_triplets += len(ac) + A1 = csr_matrix((2 * loss1[ac], (imp_row[ac], imp_col[ac])), shape) + + # Consider margin violations to the samples in imp_col + loss2 = np.maximum(inflated_dist_tn[imp_col, k] - dist_impostors, 0) + ac = np.where(loss2 > 0)[0] + n_active_triplets += len(ac) + A2 = csc_matrix((2 * loss2[ac], (imp_row[ac], imp_col[ac])), shape) + + # Update the loss + loss += np.dot(loss1, loss1) + np.dot(loss2, loss2) + + # Update the weight matrix for gradient computation + val = (A1.sum(1).ravel() + A2.sum(0)).getA1() + A3 = csr_matrix((val, (sample_range, target_neighbors[:, k])), shape) + A0 = A0 - A1 - A2 + A3 + + grad = _sum_weighted_outer_differences(X, A0) + + return loss, grad, n_active_triplets + + +########################## +# Some helper functions # +######################### + + +def _paired_distances_chunked(X, ind_a, ind_b, squared=True): + """Equivalent to row_norms(X[ind_a] - X[ind_b], squared=squared). + + Parameters + ---------- + X : ndarray of shape (n_samples, n_features) + An array of data samples. + + ind_a : ndarray of shape (n_indices,) + An array of indices referring to samples in X. + + ind_b : ndarray of shape (n_indices,) + Another array of indices referring to samples in X. + + squared : bool (default=True) + Whether to return the squared distances. + + Returns + ------- + distances : ndarray of shape (n_indices,) + An array of pairwise, optionally squared, distances. 
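+
+    Notes
+    -----
+    The differences ``X[ind_a] - X[ind_b]`` are computed in row chunks
+    (sized with ``get_chunk_n_rows``), so the full difference matrix is
+    never allocated at once.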
+ """ + + n_pairs = len(ind_a) + row_bytes = X.shape[1] * X.itemsize + chunk_size = get_chunk_n_rows(row_bytes, max_n_rows=n_pairs) + + distances = np.zeros(n_pairs) + for sl in gen_batches(n_pairs, chunk_size): + distances[sl] = row_norms(X[ind_a[sl]] - X[ind_b[sl]], True) + + return distances if squared else np.sqrt(distances, out=distances) + + +def _sum_weighted_outer_differences(X, weights): + """Compute the sum of weighted outer pairwise differences. + + Parameters + ---------- + X : ndarray of shape (n_samples, n_features) + An array of data samples. + + weights : sparse matrix, shape (n_samples, n_samples) + A sparse weights matrix in CSR format. + + Returns + ------- + sum_weighted_outer_diffs : ndarray of shape (n_features, n_features) + The sum of all outer weighted differences. + """ + + weights_sym = weights + weights.T + diagonal = weights_sym.sum(1).getA() + laplacian_dot_X = diagonal * X - safe_sparse_dot(weights_sym, X, dense_output=True) + sum_weighted_outer_diffs = np.dot(X.T, laplacian_dot_X) + + return sum_weighted_outer_diffs diff --git a/sklearn/neighbors/_nca.py b/sklearn/neighbors/_nca.py index 96cdc3052c66e..6e60ba921be38 100644 --- a/sklearn/neighbors/_nca.py +++ b/sklearn/neighbors/_nca.py @@ -3,7 +3,7 @@ """ # Authors: William de Vazelhes -# John Chiotellis +# John Chiotellis # License: BSD 3 clause from warnings import warn diff --git a/sklearn/neighbors/tests/test_lmnn.py b/sklearn/neighbors/tests/test_lmnn.py new file mode 100644 index 0000000000000..5aeceb95abd66 --- /dev/null +++ b/sklearn/neighbors/tests/test_lmnn.py @@ -0,0 +1,633 @@ +import numpy as np +from scipy.optimize import check_grad + +from sklearn import datasets +from sklearn.neighbors import LargeMarginNearestNeighbor +from sklearn.neighbors import KNeighborsClassifier +from sklearn.neighbors._lmnn import _paired_distances_chunked +from sklearn.neighbors._lmnn import _compute_push_loss +from sklearn.metrics.pairwise import paired_euclidean_distances +from sklearn.model_selection import train_test_split +from sklearn.utils import check_random_state +from sklearn.utils.extmath import row_norms +from sklearn.exceptions import ConvergenceWarning +from sklearn.utils._testing import assert_array_equal +from sklearn.utils._testing import assert_allclose +from sklearn.utils._testing import assert_raises +from sklearn.utils._testing import assert_raise_message +from sklearn.utils._testing import assert_warns_message + + +rng = np.random.RandomState(0) +# load and shuffle iris dataset +iris = datasets.load_iris() +perm = rng.permutation(iris.target.size) +iris_data = iris.data[perm] +iris_target = iris.target[perm] + +# load and shuffle digits +digits = datasets.load_digits() +perm = rng.permutation(digits.target.size) +digits_data = digits.data[perm] +digits_target = digits.target[perm] + + +def test_neighbors_iris(): + # Sanity checks on the iris dataset + # Puts three points of each label in the plane and performs a + # nearest neighbor query on points near the decision boundary. 
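+    # Here, LMNN is fit on the full iris data: with one target neighbor a
+    # 1-nearest-neighbor classifier is expected to fit the training data
+    # exactly, and with nine target neighbors a 9-nearest-neighbor
+    # classifier should still score above 0.95 on the training data.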
+
+    lmnn = LargeMarginNearestNeighbor(n_neighbors=1)
+    lmnn.fit(iris_data, iris_target)
+    knn = KNeighborsClassifier(n_neighbors=lmnn.n_neighbors_)
+    LX = lmnn.transform(iris_data)
+    knn.fit(LX, iris_target)
+    y_pred = knn.predict(LX)
+
+    assert_array_equal(y_pred, iris_target)
+
+    lmnn.set_params(n_neighbors=9)
+    lmnn.fit(iris_data, iris_target)
+    knn = KNeighborsClassifier(n_neighbors=lmnn.n_neighbors_)
+    # Re-transform the data with the newly learned transformation
+    LX = lmnn.transform(iris_data)
+    knn.fit(LX, iris_target)
+
+    assert knn.score(LX, iris_target) > 0.95
+
+
+def test_neighbors_digits():
+    # Sanity check on the digits dataset
+    # the 'brute' algorithm has been observed to fail if the input
+    # dtype is uint8 due to overflow in distance calculations.
+
+    X = digits_data.astype('uint8')
+    y = digits_target
+    n_samples, n_features = X.shape
+    train_test_boundary = int(n_samples * 0.8)
+    train = np.arange(0, train_test_boundary)
+    test = np.arange(train_test_boundary, n_samples)
+    X_train, y_train, X_test, y_test = X[train], y[train], X[test], y[test]
+
+    k = 1
+    lmnn = LargeMarginNearestNeighbor(n_neighbors=k, max_iter=30)
+    lmnn.fit(X_train, y_train)
+    knn = KNeighborsClassifier(n_neighbors=k)
+    knn.fit(lmnn.transform(X_train), y_train)
+    score_uint8 = knn.score(lmnn.transform(X_test), y_test)
+
+    knn.fit(lmnn.transform(X_train.astype(float)), y_train)
+    score_float = knn.score(lmnn.transform(X_test.astype(float)), y_test)
+
+    assert score_uint8 == score_float
+
+
+def test_params_validation():
+    # Test that invalid parameters raise value error
+    X = np.arange(12).reshape(4, 3)
+    y = [1, 1, 2, 2]
+    LMNN = LargeMarginNearestNeighbor
+
+    # TypeError
+    assert_raises(TypeError, LMNN(n_neighbors=1.3).fit, X, y)
+    assert_raises(TypeError, LMNN(max_iter='21').fit, X, y)
+    assert_raises(TypeError, LMNN(verbose='true').fit, X, y)
+    assert_raises(TypeError, LMNN(max_impostors=23.1).fit, X, y)
+    assert_raises(TypeError, LMNN(tol='1').fit, X, y)
+    assert_raises(TypeError, LMNN(n_components='invalid').fit, X, y)
+    assert_raises(TypeError, LMNN(n_jobs='yes').fit, X, y)
+    assert_raises(TypeError, LMNN(warm_start=1).fit, X, y)
+    assert_raises(TypeError, LMNN(impostor_store=0.5).fit, X, y)
+    assert_raises(TypeError, LMNN(neighbors_params=65).fit, X, y)
+    assert_raises(TypeError, LMNN(push_loss_weight='0.3').fit, X, y)
+
+    # ValueError
+    assert_raise_message(ValueError,
+                         "`init` must be 'pca', 'identity', or a numpy "
+                         "array of shape (n_components, n_features).",
+                         LMNN(init=1).fit, X, y)
+
+    assert_raise_message(ValueError,
+                         '`n_neighbors`= -1, must be >= 1.',
+                         LMNN(n_neighbors=-1).fit, X, y)
+
+    assert_raise_message(ValueError,
+                         '`n_neighbors`= {}, must be <= {}.'
+                         .format(X.shape[0], X.shape[0] - 1),
+                         LMNN(n_neighbors=X.shape[0]).fit, X, y)
+
+    assert_raise_message(ValueError,
+                         '`max_iter`= -1, must be >= 1.',
+                         LMNN(max_iter=-1).fit, X, y)
+    assert_raise_message(ValueError,
+                         '`max_impostors`= -1, must be >= 1.',
+                         LMNN(max_impostors=-1).fit, X, y)
+    assert_raise_message(ValueError,
+                         "`impostor_store` must be 'auto', 'sparse' "
+                         "or 'list'.",
+                         LMNN(impostor_store='dense').fit, X, y)
+
+    assert_raise_message(ValueError,
+                         '`push_loss_weight`= 2.0, must be <= 1.0.',
+                         LMNN(push_loss_weight=2.).fit, X, y)
+
+    assert_raise_message(ValueError,
+                         '`push_loss_weight` cannot be zero.',
+                         LMNN(push_loss_weight=0.).fit, X, y)
+
+    rng = np.random.RandomState(42)
+    init = rng.rand(5, 3)
+    assert_raise_message(ValueError,
+                         'The output dimensionality ({}) of the given linear '
+                         'transformation `init` cannot be greater than its '
+                         'input dimensionality ({}).'
+ .format(init.shape[0], init.shape[1]), + LMNN(init=init).fit, X, y) + + n_components = 10 + assert_raise_message(ValueError, + 'The preferred output dimensionality ' + '`n_components` ({}) cannot be greater ' + 'than the given data dimensionality ({})!' + .format(n_components, X.shape[1]), + LMNN(n_components=n_components).fit, X, y) + + n_jobs = 0 + assert_raise_message(ValueError, + 'n_jobs == 0 in Parallel has no meaning', + LMNN(n_jobs=n_jobs).fit, X, y) + + # test min_class_size < 2 + y = [1, 1, 1, 2] + assert_raise_message(ValueError, + 'LargeMarginNearestNeighbor needs at least 2 ' + 'non-singleton classes, got 1.', + LMNN(n_neighbors=1).fit, X, y) + + +def test_n_neighbors(): + X = np.arange(12).reshape(4, 3) + y = [1, 1, 2, 2] + + lmnn = LargeMarginNearestNeighbor(n_neighbors=2) + assert_warns_message(UserWarning, + '`n_neighbors` (=2) is not less than the number of ' + 'samples in the smallest non-singleton class (=2). ' + '`n_neighbors_` will be set to 1 for estimation.', + lmnn.fit, X, y) + + +def test_n_components(): + X = np.arange(12).reshape(4, 3) + y = [1, 1, 2, 2] + + rng = np.random.RandomState(42) + init = rng.rand(X.shape[1] - 1, 3) + + # n_components = X.shape[1] != transformation.shape[0] + n_components = X.shape[1] + lmnn = LargeMarginNearestNeighbor(init=init, n_components=n_components) + assert_raise_message(ValueError, + 'The preferred output dimensionality ' + '`n_components` ({}) does not match ' + 'the output dimensionality of the given ' + 'linear transformation `init` ({})!' + .format(n_components, init.shape[0]), + lmnn.fit, X, y) + + # n_components > X.shape[1] + n_components = X.shape[1] + 2 + lmnn = LargeMarginNearestNeighbor(init=init, n_components=n_components) + assert_raise_message(ValueError, + 'The preferred output dimensionality ' + '`n_components` ({}) cannot be greater ' + 'than the given data dimensionality ({})!' + .format(n_components, X.shape[1]), + lmnn.fit, X, y) + + # n_components < X.shape[1] + lmnn = LargeMarginNearestNeighbor(n_components=2, init='identity') + lmnn.fit(X, y) + + +def test_init_transformation(): + rng = np.random.RandomState(42) + X, y = datasets.make_classification(n_samples=30, n_features=5, + n_redundant=0, random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + # Initialize with identity + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, init='identity') + lmnn.fit(X_train, y_train) + + # Initialize with PCA + lmnn_pca = LargeMarginNearestNeighbor(n_neighbors=3, init='pca') + lmnn_pca.fit(X_train, y_train) + + # Initialize with a transformation given by the user + init = rng.rand(X.shape[1], X.shape[1]) + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, init=init) + lmnn.fit(X_train, y_train) + + # init.shape[1] must match X.shape[1] + init = rng.rand(X.shape[1], X.shape[1] + 1) + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, init=init) + assert_raise_message(ValueError, + 'The input dimensionality ({}) of the given ' + 'linear transformation `init` must match the ' + 'dimensionality of the given inputs `X` ({}).' + .format(init.shape[1], X.shape[1]), + lmnn.fit, X_train, y_train) + + # init.shape[0] must be <= init.shape[1] + init = rng.rand(X.shape[1] + 1, X.shape[1]) + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, init=init) + assert_raise_message(ValueError, + 'The output dimensionality ({}) of the given ' + 'linear transformation `init` cannot be ' + 'greater than its input dimensionality ({}).' 
+ .format(init.shape[0], init.shape[1]), + lmnn.fit, X_train, y_train) + + # init.shape[0] must match n_components + init = rng.rand(X.shape[1], X.shape[1]) + n_components = X.shape[1] - 2 + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, init=init, + n_components=n_components) + assert_raise_message(ValueError, + 'The preferred output dimensionality ' + '`n_components` ({}) does not match ' + 'the output dimensionality of the given ' + 'linear transformation `init` ({})!' + .format(n_components, init.shape[0]), + lmnn.fit, X_train, y_train) + + +def test_warm_start_validation(): + X, y = datasets.make_classification(n_samples=30, n_features=5, + n_classes=4, n_redundant=0, + n_informative=5, random_state=0) + + lmnn = LargeMarginNearestNeighbor(warm_start=True, max_iter=5) + lmnn.fit(X, y) + + X_less_features, y = \ + datasets.make_classification(n_samples=30, n_features=4, n_classes=4, + n_redundant=0, n_informative=4, + random_state=0) + assert_raise_message(ValueError, + 'The new inputs dimensionality ({}) does not ' + 'match the input dimensionality of the ' + 'previously learned transformation ({}).' + .format(X_less_features.shape[1], + lmnn.components_.shape[1]), + lmnn.fit, X_less_features, y) + + +def test_warm_start_effectiveness(): + # A 1-iteration second fit on same data should give almost same result + # with warm starting, and quite different result without warm starting. + + X, y = datasets.make_classification(n_samples=30, n_features=5, + n_redundant=0, random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y) + n_iter = 10 + + lmnn_warm = LargeMarginNearestNeighbor(n_neighbors=3, warm_start=True, + max_iter=n_iter, random_state=0) + lmnn_warm.fit(X_train, y_train) + transformation_warm = lmnn_warm.components_ + lmnn_warm.max_iter = 1 + lmnn_warm.fit(X_train, y_train) + transformation_warm_plus_one = lmnn_warm.components_ + + lmnn_cold = LargeMarginNearestNeighbor(n_neighbors=3, warm_start=False, + max_iter=n_iter, random_state=0) + lmnn_cold.fit(X_train, y_train) + transformation_cold = lmnn_cold.components_ + lmnn_cold.max_iter = 1 + lmnn_cold.fit(X_train, y_train) + transformation_cold_plus_one = lmnn_cold.components_ + + diff_warm = np.sum(np.abs(transformation_warm_plus_one - + transformation_warm)) + diff_cold = np.sum(np.abs(transformation_cold_plus_one - + transformation_cold)) + + assert diff_warm < 2.0, "Transformer changed significantly after one " \ + "iteration even though it was warm-started." + + assert diff_cold > diff_warm, "Cold-started transformer changed less " \ + "significantly than warm-started " \ + "transformer after one iteration." 
+ + +def test_max_impostors(): + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_impostors=1, + impostor_store='list') + lmnn.fit(iris_data, iris_target) + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_impostors=1, + impostor_store='sparse') + lmnn.fit(iris_data, iris_target) + + +def test_neighbors_params(): + from scipy.spatial.distance import hamming + + params = {'algorithm': 'brute', 'metric': hamming} + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, neighbors_params=params) + lmnn.fit(iris_data, iris_target) + components_hamming = lmnn.components_ + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3) + lmnn.fit(iris_data, iris_target) + components_euclidean = lmnn.components_ + + assert not np.allclose(components_hamming, components_euclidean) + + +def test_impostor_store(): + k = 3 + lmnn = LargeMarginNearestNeighbor(n_neighbors=k, + init='identity', impostor_store='list') + lmnn.fit(iris_data, iris_target) + components_list = lmnn.components_ + + lmnn = LargeMarginNearestNeighbor(n_neighbors=k, + init='identity', impostor_store='sparse') + lmnn.fit(iris_data, iris_target) + components_sparse = lmnn.components_ + + assert_allclose(components_list, components_sparse, + err_msg='Toggling `impostor_store` results in a ' + 'different solution.') + + +def test_callback(capsys): + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, callback='my_cb') + assert_raise_message(ValueError, + '`callback` is not callable.', + lmnn.fit, iris_data, iris_target) + + max_iter = 10 + + def my_cb(transformation, n_iter): + assert transformation.shape == (iris_data.shape[1] ** 2,) + rem_iter = max_iter - n_iter + print('{} iterations remaining...'.format(rem_iter)) + + # assert that my_cb is called + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, callback=my_cb, + max_iter=max_iter, verbose=1) + lmnn.fit(iris_data, iris_target) + out, _ = capsys.readouterr() + + assert '{} iterations remaining...'.format(max_iter - 1) in out + + +def test_store_opt_result(): + X = iris_data + y = iris_target + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_iter=5, + store_opt_result=True) + lmnn.fit(X_train, y_train) + transformation = lmnn.opt_result_.x + assert transformation.size == X.shape[1]**2 + + +def test_verbose(capsys): + # assert there is proper output when verbose = 1 + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, verbose=1) + lmnn.fit(iris_data, iris_target) + out, _ = capsys.readouterr() + + # check output + assert "[LargeMarginNearestNeighbor]" in out + assert "Finding principal components" in out + assert "Finding the target neighbors" in out + assert "Computing static part of the gradient" in out + assert "Finding principal components" in out + assert "Training took" in out + + # assert by default there is no output (verbose=0) + lmnn = LargeMarginNearestNeighbor(n_neighbors=3) + lmnn.fit(iris_data, iris_target) + out, _ = capsys.readouterr() + + # check output + assert out == '' + + +def test_random_state(): + """Assert that when having more than max_impostors (forcing sampling), + the same impostors will be sampled given the same random_state and + different impostors will be sampled given a different random_state + leading to a different transformation""" + + X = iris_data + y = iris_target + + # Use init='identity' to ensure reproducibility + params = {'n_neighbors': 3, 'max_impostors': 5, 'random_state': 1, + 'max_iter': 10, 'init': 'identity'} + + lmnn = LargeMarginNearestNeighbor(**params) + 
lmnn.fit(X, y) + transformation_1 = lmnn.components_ + + lmnn = LargeMarginNearestNeighbor(**params) + lmnn.fit(X, y) + transformation_2 = lmnn.components_ + + # This assertion fails on 32bit systems if init='pca' + assert_allclose(transformation_1, transformation_2) + + params['random_state'] = 2 + lmnn = LargeMarginNearestNeighbor(**params) + lmnn.fit(X, y) + transformation_3 = lmnn.components_ + + assert not np.allclose(transformation_2, transformation_3) + + +def test_same_lmnn_parallel(): + X, y = datasets.make_classification(n_samples=30, n_features=5, + n_redundant=0, random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3) + lmnn.fit(X_train, y_train) + components = lmnn.components_ + + lmnn.set_params(n_jobs=3) + lmnn.fit(X_train, y_train) + components_parallel = lmnn.components_ + + assert_allclose(components, components_parallel) + + +def test_singleton_class(): + X = iris_data + y = iris_target + X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y) + + # one singleton class + singleton_class = 1 + ind_singleton, = np.where(y_tr == singleton_class) + y_tr[ind_singleton] = 2 + y_tr[ind_singleton[0]] = singleton_class + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_iter=30) + lmnn.fit(X_tr, y_tr) + + # One non-singleton class + X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y) + ind_1, = np.where(y_tr == 1) + ind_2, = np.where(y_tr == 2) + y_tr[ind_1] = 0 + y_tr[ind_1[0]] = 1 + y_tr[ind_2] = 0 + y_tr[ind_2[0]] = 2 + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_iter=30) + assert_raise_message(ValueError, + 'LargeMarginNearestNeighbor needs at least 2 ' + 'non-singleton classes, got 1.', + lmnn.fit, X_tr, y_tr) + + +def test_convergence_warning(): + + lmnn = LargeMarginNearestNeighbor(n_neighbors=3, max_iter=2, verbose=1) + cls_name = lmnn.__class__.__name__ + assert_warns_message(ConvergenceWarning, + '[{}] LMNN did not converge'.format(cls_name), + lmnn.fit, iris_data, iris_target) + + +def test_paired_distances_chunked(): + n, d = 10000, 100 # 4 or 8 MiB + X = rng.rand(n, d) + ind_a = rng.permutation(n) + ind_b = rng.permutation(n) + + distances = paired_euclidean_distances(X[ind_a], X[ind_b]) + distances_chunked = _paired_distances_chunked(X, ind_a, ind_b, + squared=False) + assert_array_equal(distances, distances_chunked) + + +def test_find_impostors(): + """Test if the impostors found are correct + + Create data points for class A as the 4 corners of the unit square. Create + data points for class B by shifting the points of class A, r = 4 units + along the x-axis. Using 1 nearest neighbor, all distances from samples to + their target neighbors are 2. + Impostors are in squared distance less than the squared target neighbor + distance (2**2 = 4) plus a margin of 1. Therefore, impostors are all + differently labeled samples in squared distance less than 5. In this test + only the inner most samples have one impostor each which lies on the + same y-coordinate and therefore have a squared distance of 4. + + The difference of using a sparse or dense data structure for the impostors + storage is tested in test_impostor_store(). + """ + + class_distance = 4. 
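+    # Class A: the four corners of the square of side 2 centred at the
+    # origin; class B: the same corners shifted by `class_distance` along x.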
+ X_a = np.array([[-1., 1], [-1., -1.], [1., 1.], [1., -1.]]) + X_b = X_a + np.array([class_distance, 0]) + X = np.concatenate((X_a, X_b)) + y = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + lmnn = LargeMarginNearestNeighbor(n_neighbors=1) + lmnn.random_state_ = check_random_state(0) + lmnn.n_neighbors_ = 1 + n_samples = X.shape[0] + classes = np.unique(y) + target_neighbors = lmnn._select_target_neighbors_wrapper(X, y, classes) + dist_tn = row_norms(X - X[target_neighbors[:, 0]], squared=True) + dist_tn = dist_tn[:, None] + dist_tn += 1 + margin_radii = dist_tn[:, -1] + + # Groundtruth impostors + gt_impostors_mask = np.zeros((n_samples, n_samples), dtype=bool) + gt_impostors_mask[[4, 5], [2, 3]] = 1 + gt_impostors_mask += gt_impostors_mask.T + squared_dist = (class_distance - 2)**2 + gt_impostors_dist = gt_impostors_mask.astype(float) * squared_dist + + impostors_graph = lmnn._find_impostors(X, y, classes, margin_radii) + impostors_mask = impostors_graph.A > 0 + # Impostors are considered in one of two possible directions + impostors_mask += impostors_mask.T + impostors_dist = impostors_graph.A + impostors_graph.A.T + + assert_array_equal(impostors_mask, gt_impostors_mask) + assert_array_equal(impostors_dist, gt_impostors_dist) + + +def test_compute_push_loss(): + """Test if the push loss is computed correctly + + This test continues on the example from test_find_impostors. The push + loss is easy to compute, as we have only 4 violations and all of them + amount to 1 (squared distance to target neighbor + 1 - squared distance + to impostor = 4 + 1 - 4). + """ + + class_distance = 4. + X_a = np.array([[-1., 1], [-1., -1.], [1., 1.], [1., -1.]]) + X_b = X_a + np.array([class_distance, 0]) + X = np.concatenate((X_a, X_b)) + y = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + lmnn = LargeMarginNearestNeighbor(n_neighbors=1) + lmnn.random_state_ = check_random_state(0) + lmnn.n_neighbors_ = 1 + classes = np.unique(y) + target_neighbors = lmnn._select_target_neighbors_wrapper(X, y, classes) + dist_tn = row_norms(X - X[target_neighbors[:, 0]], squared=True) + dist_tn = dist_tn[:, None] + dist_tn += 1 + margin_radii = dist_tn[:, -1] + impostors_graph = lmnn._find_impostors(X, y, classes, margin_radii) + loss, grad, _ = _compute_push_loss(X, target_neighbors, dist_tn, + impostors_graph) + + # The loss should be 4. (1. for each of the 4 violation) + assert loss == 4. + + +def test_loss_grad_lbfgs(): + """Test gradient of loss function + + Assert that the gradient is almost equal to its finite differences + approximation. 
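+    The finite differences error reported by `scipy.optimize.check_grad` is
+    divided by the norm of the analytic gradient, so the assertion below
+    checks a relative rather than an absolute error.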
+ """ + + X, y = datasets.make_classification() + classes = np.unique(y) + L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1]) + lmnn = LargeMarginNearestNeighbor() + lmnn.n_neighbors_ = lmnn.n_neighbors + lmnn.n_iter_ = 0 + target_neighbors = lmnn._select_target_neighbors_wrapper(X, y, classes) + grad_static = lmnn._compute_grad_static(X, target_neighbors) + + kwargs = { + 'classes': classes, + 'target_neighbors': target_neighbors, + 'grad_static': grad_static, + 'use_sparse': False + } + + def fun(L): + return lmnn._loss_grad_lbfgs(L, X, y, **kwargs)[0] + + def grad(L): + return lmnn._loss_grad_lbfgs(L, X, y, **kwargs)[1] + + # compute relative error + rel_diff = check_grad(fun, grad, L.ravel()) / np.linalg.norm(grad(L)) + np.testing.assert_almost_equal(rel_diff, 0., decimal=5) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index e4513a62bf07e..4402c991884d1 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1082,3 +1082,73 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): RuntimeWarning, ) return out + + +def _euclidean_distances_without_checks( + X, Y=None, Y_norm_squared=None, squared=False, X_norm_squared=None, clip=True +): + """sklearn.pairwise.euclidean_distances without checks with optional clip. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples_1, n_features) + + Y : {array-like, sparse matrix}, shape (n_samples_2, n_features) + + Y_norm_squared : array-like, shape (n_samples_2, ), optional + Pre-computed dot-products of vectors in Y (e.g., + ``(Y**2).sum(axis=1)``) + + squared : boolean, optional + Return squared Euclidean distances. + + X_norm_squared : array-like, shape = [n_samples_1], optional + Pre-computed dot-products of vectors in X (e.g., + ``(X**2).sum(axis=1)``) + + clip : bool, optional (default=True) + Whether to explicitly enforce computed distances to be non-negative. + Some algorithms, such as LargeMarginNearestNeighbor, compare distances + to strictly positive values (distances to farthest target neighbors + + margin) only to make a binary decision (if a sample is an impostor + or not). In such cases, it does not matter if the distance is zero + or negative, since it is definitely smaller than a strictly positive + value. + + Returns + ------- + distances : array, shape (n_samples_1, n_samples_2) + + """ + + if Y is None: + Y = X + + if X_norm_squared is not None: + XX = X_norm_squared + if XX.shape == (1, X.shape[0]): + XX = XX.T + else: + XX = row_norms(X, squared=True)[:, np.newaxis] + + if X is Y: # shortcut in the common case euclidean_distances(X, X) + YY = XX.T + elif Y_norm_squared is not None: + YY = np.atleast_2d(Y_norm_squared) + else: + YY = row_norms(Y, squared=True)[np.newaxis, :] + + distances = safe_sparse_dot(X, Y.T, dense_output=True) + distances *= -2 + distances += XX + distances += YY + + if clip: + np.maximum(distances, 0, out=distances) + + if X is Y: + # Ensure that distances between vectors and themselves are set to 0.0. + # This may not be the case due to floating point rounding errors. 
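+        # `distances` is square here, so assigning to every
+        # (shape[0] + 1)-th element of the flattened array zeroes exactly
+        # the main diagonal.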
+ distances.flat[:: distances.shape[0] + 1] = 0.0 + + return distances if squared else np.sqrt(distances, out=distances) diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 07a553c8cf09d..4513411d8c23e 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -32,6 +32,8 @@ from sklearn.utils.extmath import softmax from sklearn.utils.extmath import stable_cumsum from sklearn.utils.extmath import safe_sparse_dot +from sklearn.utils.extmath import _euclidean_distances_without_checks +from sklearn.metrics.pairwise import euclidean_distances from sklearn.datasets import make_low_rank_matrix, make_sparse_spd_matrix @@ -1005,3 +1007,22 @@ def test_safe_sparse_dot_dense_output(dense_output): if dense_output: expected = expected.toarray() assert_allclose_dense_sparse(actual, expected) + + +def test_euclidean_distances_without_checks(): + rng = np.random.RandomState(0) + X = rng.rand(100, 20) + Y = rng.rand(50, 20) + + # 2 matrices with no precomputed norms + distances1 = euclidean_distances(X, Y) + distances2 = _euclidean_distances_without_checks(X, Y) + + assert_array_equal(distances1, distances2) + + # 1 matrix with itself with squared row_norms precomputed and transposed + XX = row_norms(X, squared=True)[np.newaxis, :] + distances1 = euclidean_distances(X, X_norm_squared=XX) + distances2 = _euclidean_distances_without_checks(X, X_norm_squared=XX) + + assert_array_equal(distances1, distances2)