FEAT Large Margin Nearest Neighbor implementation #8602
@@ -843,3 +843,214 @@ added space complexity in the operation.

`Wikipedia entry on Neighborhood Components Analysis
<https://en.wikipedia.org/wiki/Neighbourhood_components_analysis>`_

.. _lmnn:

Large Margin Nearest Neighbor
=============================

.. sectionauthor:: John Chiotellis <johnyc.code@gmail.com>

Large Margin Nearest Neighbor (LMNN, :class:`LargeMarginNearestNeighbor`) is
a metric learning algorithm which aims to improve the accuracy of
nearest neighbors classification compared to the standard Euclidean distance.

.. |lmnn_illustration_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_001.png
   :target: ../auto_examples/neighbors/plot_lmnn_illustration.html
   :scale: 50

.. |lmnn_illustration_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_illustration_002.png
   :target: ../auto_examples/neighbors/plot_lmnn_illustration.html
   :scale: 50

.. centered:: |lmnn_illustration_1| |lmnn_illustration_2|


For each training sample, the algorithm fixes :math:`k` "target neighbors",
namely the :math:`k`-nearest training samples (as measured by the Euclidean
distance) that share the same label. Given these target neighbors, LMNN
learns a linear transformation of the data by optimizing a trade-off between
two goals. The first one is to make each (transformed) point closer to its
target neighbors than to any differently-labeled point by a large margin,
thereby enclosing the target neighbors in a sphere around the reference
sample. Data samples from different classes that violate this margin are
called "impostors". The second goal is to minimize the distances of each
sample to its target neighbors, which can be seen as a form of regularization.
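
As an illustration only (this is not how the estimator exposes or computes
them internally), the target neighbors could be found along these lines with
:class:`NearestNeighbors`, by searching for the :math:`k` nearest points
within each class::

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.neighbors import NearestNeighbors

    X, y = load_iris(return_X_y=True)
    k = 3
    target_neighbors = np.empty((X.shape[0], k), dtype=int)
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)             # samples of this class
        nn = NearestNeighbors(n_neighbors=k + 1).fit(X[idx])
        # the nearest neighbor of each point is the point itself, so drop it
        same_class = nn.kneighbors(X[idx], return_distance=False)[:, 1:]
        target_neighbors[idx] = idx[same_class]      # map back to global indices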

Classification
--------------

Combined with a nearest neighbors classifier (:class:`KNeighborsClassifier`),
this method is attractive for classification because it can naturally
handle multi-class problems without any increase in the model size, and only
a single parameter (``n_neighbors``) has to be selected by the user before
training.

Large Margin Nearest Neighbor classification has been shown to work well in
practice for data sets of varying size and difficulty. In contrast to
related methods such as Linear Discriminant Analysis, LMNN does not make any
assumptions about the class distributions. The nearest neighbor classification
can naturally produce highly irregular decision boundaries.

To use this model for classification, one needs to combine a :class:`LargeMarginNearestNeighbor`
instance that learns the optimal transformation with a :class:`KNeighborsClassifier`
instance that performs the classification in the embedded space. Here is an
example using the two classes:

    >>> from sklearn.neighbors import LargeMarginNearestNeighbor
    >>> from sklearn.neighbors import KNeighborsClassifier
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import train_test_split
    >>> X, y = load_iris(return_X_y=True)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
    ...     stratify=y, test_size=0.7, random_state=42)
    >>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
    >>> lmnn.fit(X_train, y_train)
    LargeMarginNearestNeighbor(...)
    >>> # Apply the learned transformation when using KNeighborsClassifier
    >>> knn = KNeighborsClassifier(n_neighbors=3)
    >>> knn.fit(lmnn.transform(X_train), y_train)
    KNeighborsClassifier(...)
    >>> print(knn.score(lmnn.transform(X_test), y_test))
    0.971428...

Alternatively, one can create a :class:`sklearn.pipeline.Pipeline` instance
that automatically applies the transformation when fitting or predicting:

    >>> from sklearn.pipeline import Pipeline
    >>> lmnn = LargeMarginNearestNeighbor(n_neighbors=3, random_state=42)
    >>> knn = KNeighborsClassifier(n_neighbors=3)
    >>> lmnn_pipe = Pipeline([('lmnn', lmnn), ('knn', knn)])
    >>> lmnn_pipe.fit(X_train, y_train)
    Pipeline(...)
    >>> print(lmnn_pipe.score(X_test, y_test))
    0.971428...

.. |lmnn_classification_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_001.png
   :target: ../auto_examples/neighbors/plot_lmnn_classification.html
   :scale: 50

.. |lmnn_classification_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_classification_002.png
   :target: ../auto_examples/neighbors/plot_lmnn_classification.html
   :scale: 50

.. centered:: |lmnn_classification_1| |lmnn_classification_2|


The plot shows decision boundaries for nearest neighbor classification and
large margin nearest neighbor classification.

.. _lmnn_dim_reduction:

Dimensionality reduction
------------------------

[PR review comments on this heading]
- This should be referenced in
- ...should be in sklearn/decomposition? I think decomposition implies unsupervised methods, no?
- Just mention it there, but I'm a bit ambivalent. Often people use decomposition despite available supervision, so informing them of supervised alternatives seems helpful.
- Similar cases probably include linear discriminant analysis, CCA, PLS, principal component regression, ...

:class:`LargeMarginNearestNeighbor` can be used to perform supervised
dimensionality reduction. The input data are mapped to a linear subspace
consisting of the directions which minimize the LMNN objective. Unlike
unsupervised methods which aim to maximize the uncorrelatedness (PCA) or even
the independence (ICA) of the components, LMNN aims to find components that
maximize the nearest neighbors classification accuracy of the transformed
inputs. The desired output dimensionality can be set using the parameter
``n_components``. For instance, the following shows a comparison of
dimensionality reduction with Principal Component Analysis
(:class:`sklearn.decomposition.PCA`), Linear Discriminant Analysis
(:class:`sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) and
Large Margin Nearest Neighbor (:class:`LargeMarginNearestNeighbor`) on the
Olivetti faces dataset, a dataset of size :math:`n_{samples} = 400` and
:math:`n_{features} = 64 \times 64 = 4096`. The data set is split into a
training and a test set of equal size. For evaluation, the 3-nearest neighbor
classification accuracy is computed on the 2-dimensional embedding found by
each method. Each data sample belongs to one of 40 classes.
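
A rough sketch of such a comparison (not the example script itself; the
estimator's parameters ``n_neighbors`` and ``n_components`` are assumed as
described above, and running LMNN directly on the 4096 raw pixel features can
be slow)::

    from sklearn.datasets import fetch_olivetti_faces
    from sklearn.decomposition import PCA
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier, LargeMarginNearestNeighbor
    from sklearn.pipeline import make_pipeline

    faces = fetch_olivetti_faces()
    X, y = faces.data, faces.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=0.5, random_state=42)

    reducers = [PCA(n_components=2),
                LinearDiscriminantAnalysis(n_components=2),
                LargeMarginNearestNeighbor(n_neighbors=3, n_components=2,
                                           random_state=42)]
    for reducer in reducers:
        # reduce to 2 dimensions, then classify with 3-nearest neighbors
        model = make_pipeline(reducer, KNeighborsClassifier(n_neighbors=3))
        model.fit(X_train, y_train)
        print(type(reducer).__name__, model.score(X_test, y_test))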

.. |lmnn_dim_reduction_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_001.png
   :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
   :width: 32%

.. |lmnn_dim_reduction_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_002.png
   :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
   :width: 32%

.. |lmnn_dim_reduction_3| image:: ../auto_examples/neighbors/images/sphx_glr_plot_lmnn_dim_reduction_003.png
   :target: ../auto_examples/neighbors/plot_lmnn_dim_reduction.html
   :width: 32%

.. centered:: |lmnn_dim_reduction_1| |lmnn_dim_reduction_2| |lmnn_dim_reduction_3|


Mathematical formulation
------------------------

LMNN learns a linear transformation matrix :math:`L` of
size ``(n_components, n_features)``. The objective function consists of
two competing terms, the pull loss that pulls target neighbors closer to
their reference sample and the push loss that pushes impostors away:

.. math::
    \varepsilon_{\text{pull}} (L) = \sum_{i, j \rightsquigarrow i} ||L(x_i - x_j)||^2,

.. math::
    \varepsilon_{\text{push}} (L) = \sum_{i, j \rightsquigarrow i} \sum_{l}
    (1 - y_{il}) [1 + ||L(x_i - x_j)||^2 - ||L(x_i - x_l)||^2]_+,

where :math:`y_{il} = 1` if :math:`y_i = y_l` and :math:`0` otherwise,
:math:`[x]_+ = \max(0, x)` is the hinge loss, and :math:`j \rightsquigarrow i`
means that the :math:`j^{th}` sample is a target neighbor of the
:math:`i^{th}` sample.

LMNN solves the following (nonconvex) minimization problem:

.. math::
    \min_L \varepsilon(L) = (1 - \mu) \varepsilon_{\text{pull}} (L) +
    \mu \varepsilon_{\text{push}} (L) \text{, } \quad \mu \in [0,1].

The parameter :math:`\mu` (``push_loss_weight``) calibrates the trade-off
between penalizing large distances to target neighbors and penalizing margin
violations by impostors. In practice, the two terms are usually weighted
equally (:math:`\mu = 0.5`).
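
For illustration, a naive (and slow) NumPy evaluation of this objective could
look as follows, assuming ``target_neighbors`` holds, for each sample, the
indices of its :math:`k` target neighbors as in the earlier sketch; the
estimator itself uses a much more efficient formulation::

    import numpy as np

    def lmnn_objective(L, X, y, target_neighbors, push_loss_weight=0.5):
        """Naively evaluate the LMNN objective for a transformation L."""
        LX = X @ L.T                                   # transformed samples
        pull_loss, push_loss = 0.0, 0.0
        for i in range(X.shape[0]):
            d_all = np.sum((LX[i] - LX) ** 2, axis=1)  # squared distances from sample i
            for j in target_neighbors[i]:
                d_ij = d_all[j]
                pull_loss += d_ij
                # margin violations by differently labeled samples ("impostors")
                margins = 1 + d_ij - d_all[y != y[i]]
                push_loss += np.sum(np.maximum(margins, 0))   # hinge loss
        mu = push_loss_weight
        return (1 - mu) * pull_loss + mu * push_loss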


Mahalanobis distance
^^^^^^^^^^^^^^^^^^^^

LMNN can be seen as learning a (squared) Mahalanobis distance metric:

.. math::
    ||L(x_i - x_j)||^2 = (x_i - x_j)^T M (x_i - x_j),

where :math:`M = L^T L` is a symmetric positive semi-definite matrix of size
``(n_features, n_features)``. The objective function of LMNN can be
rewritten and solved with respect to :math:`M` directly. This results in a
convex but constrained problem (since :math:`M` must be symmetric positive
semi-definite). See the journal paper in the References for more details.
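
This identity is easy to verify numerically (illustrative only, with an
arbitrary random transformation)::

    import numpy as np

    rng = np.random.RandomState(0)
    L = rng.randn(2, 5)               # (n_components, n_features)
    x_i, x_j = rng.randn(5), rng.randn(5)

    M = L.T @ L                       # (n_features, n_features), PSD by construction
    d = x_i - x_j
    assert np.isclose(np.sum((L @ d) ** 2), d @ M @ d)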


Implementation
--------------

This implementation closely follows the MATLAB implementation found at
https://bitbucket.org/mlcircus/lmnn, which solves the unconstrained problem.
It finds a linear transformation :math:`L` by optimizing with L-BFGS instead
of solving the constrained problem for the globally optimal distance metric.
Unlike the paper, this implementation optimizes the *squared* hinge loss
(which makes the objective differentiable).
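
As an illustration only (the estimator's internal solver is more involved and
supplies the analytic gradient), the unconstrained problem could be handed to
SciPy's L-BFGS solver roughly like this, reusing the hypothetical
``lmnn_objective`` sketched above::

    import numpy as np
    from scipy.optimize import minimize

    def fit_lmnn_sketch(X, y, target_neighbors, n_components, random_state=0):
        n_features = X.shape[1]
        rng = np.random.RandomState(random_state)
        L0 = rng.randn(n_components, n_features)     # random initial transformation

        def objective(L_flat):
            L = L_flat.reshape(n_components, n_features)
            return lmnn_objective(L, X, y, target_neighbors)

        # No analytic gradient here, so L-BFGS falls back to finite differences;
        # the actual implementation passes the gradient, which is much faster.
        result = minimize(objective, L0.ravel(), method='L-BFGS-B')
        return result.x.reshape(n_components, n_features)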

See the examples below and the docstring of :meth:`LargeMarginNearestNeighbor.fit`
for further information.

.. topic:: Examples:

    * :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_classification.py`
    * :ref:`sphx_glr_auto_examples_neighbors_plot_lmnn_dim_reduction.py`


.. topic:: References:

    * `"Distance Metric Learning for Large Margin Nearest Neighbor Classification"
      <http://jmlr.csail.mit.edu/papers/volume10/weinberger09a/weinberger09a.pdf>`_,
      Weinberger, Kilian Q., and Lawrence K. Saul, Journal of Machine Learning
      Research, Vol. 10, Feb. 2009, pp. 207-244.

    * `Wikipedia entry on Large Margin Nearest Neighbor
      <https://en.wikipedia.org/wiki/Large_margin_nearest_neighbor>`_

@@ -0,0 +1,85 @@

"""
===========================================================================
Comparing Nearest Neighbors with and without Large Margin Nearest Neighbor
===========================================================================

This example compares nearest neighbors classification with and without
Large Margin Nearest Neighbor.

It will plot the decision boundaries for each class determined by a simple
Nearest Neighbors classifier against the decision boundaries determined by a
Large Margin Nearest Neighbor classifier. The latter aims to find a distance
metric that maximizes the nearest neighbor classification accuracy on a given
training set.
"""

# Author: John Chiotellis <johnyc.code@gmail.com>
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, LargeMarginNearestNeighbor
from sklearn.pipeline import Pipeline


print(__doc__)

n_neighbors = 3

# import some data to play with
iris = datasets.load_iris()

# We only take the first two features. We could avoid this ugly
# slicing by using a two-dimensional dataset.
X = iris.data[:, :2]
y = iris.target

X_train, X_test, y_train, y_test = \
    train_test_split(X, y, stratify=y, test_size=0.7, random_state=42)

h = .01  # step size in the mesh

# Create color maps
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

names = ['K-Nearest Neighbors', 'Large Margin Nearest Neighbor']

classifiers = [KNeighborsClassifier(n_neighbors=n_neighbors),
               Pipeline([('lmnn', LargeMarginNearestNeighbor(
                             n_neighbors=n_neighbors, random_state=42)),
                         ('knn', KNeighborsClassifier(n_neighbors=n_neighbors))
                         ])
               ]

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

for name, clf in zip(names, classifiers):

    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=.8, shading='auto')

    # Plot also the training and testing points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("{} (k = {})".format(name, n_neighbors))
    plt.text(0.9, 0.1, '{:.2f}'.format(score), size=15,
             ha='center', va='center', transform=plt.gca().transAxes)

plt.show()

[PR review comments on the example]
- I think that the example should demo the use of the transform method.
- You could show the feature space before and after learning the metric.
- Sure, I can do that.
- Building on Gael's comment, it would be nice to somehow plot a form of the linear subspace for the sake of intuitive appeal to the user...

[PR review comments on "without any increase in the model size"]
- Not sure this statement here is fully bullet-proof. As a memory-based algorithm, kNN-type procedures use the training data as the 'model', which is why the notion of 'model size' may be less immediately clear in this context.
- That's true, but even in that case, it holds that nothing needs to be modified in the algorithm to handle multi-class problems (in contrast to e.g. SVMs). Should we maybe change "any increase in the model size" to "any change in the algorithm"? I think this paragraph was by @bellet?
- Well, the idea is that the model size of LMNN (as well as kNN) does not depend on the number of classes. In the sense that you can have as many labels as you want in the training data, the number of parameters to learn in LMNN remains the same (and the model size/complexity of kNN remains the same for a fixed training set size).