8000 MNT use `scipy.sparse.csgraph.shortest_path` in Isomap (#20531) · rth/scikit-learn@81165ca · GitHub
[go: up one dir, main page]

Skip to content

Commit 81165ca

Browse files
TomDLTglemaitre
andauthored
MNT use scipy.sparse.csgraph.shortest_path in Isomap (scikit-learn#20531)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 6b2d5a9 commit 81165ca

File tree

9 files changed

+290
-642
lines changed

9 files changed

+290
-642
lines changed

doc/whats_new/v1.0.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ random sampling procedures.
4343
- |Fix| :class:`manifold.TSNE` now avoids numerical underflow issues during
4444
affinity matrix computation.
4545

46+
- |Fix| :class:`manifold.Isomap` now connects disconnected components of the
47+
neighbors graph along some minimum distance pairs, instead of changing
48+
every infinite distances to zero.
49+
4650
- |Fix| The splitting criterion of :class:`tree.DecisionTreeClassifier` and
4751
:class:`tree.DecisionTreeRegressor` can be impacted by a fix in the handling
4852
of rounding errors. Previously some extra spurious splits could occur.
@@ -505,6 +509,12 @@ Changelog
505509
be scaled to have standard deviation 1e-4 in 1.2.
506510
:pr:`19491` by :user:`Dmitry Kobak <dkobak>`.
507511

512+
- |Fix| :class:`manifold.Isomap` now uses `scipy.sparse.csgraph.shortest_path`
513+
to compute the graph shortest path. It also connects disconnected components
514+
of the neighbors graph along some minimum distance pairs, instead of changing
515+
every infinite distances to zero. :pr:`20531` by `Roman Yurchak`_ and `Tom
516+
Dupre la Tour`_.
517+
508518
:mod:`sklearn.metrics`
509519
......................
510520

@@ -721,6 +731,10 @@ Changelog
721731
these functions were not documented and part from the public API.
722732
:pr:`20521` by :user:`Olivier Grisel <ogrisel>`.
723733

734+
- |API| Fixed several bugs in :func:`utils.graph.graph_shortest_path`, which is
735+
now deprecated. Use `scipy.sparse.csgraph.shortest_path` instead. :pr:`20531`
736+
by `Tom Dupre la Tour`_.
737+
724738
:mod:`sklearn.validation`
725739
.........................
726740

sklearn/cluster/_agglomerative.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from scipy.sparse.csgraph import connected_components
1616

1717
from ..base import BaseEstimator, ClusterMixin
18-
from ..metrics.pairwise import paired_distances, pairwise_distances
18+
from ..metrics.pairwise import paired_distances
1919
from ..neighbors import DistanceMetric
2020
from ..neighbors._dist_metrics import METRIC_MAPPING
2121
from ..utils import check_array
2222
from ..utils._fast_dict import IntFloatDict
2323
from ..utils.fixes import _astype_copy_false
24+
from ..utils.graph import _fix_connected_components
2425
from ..utils.validation import check_memory
2526

2627
# mypy error: Module 'sklearn.cluster' has no attribute '_hierarchical_fast'
@@ -68,21 +69,14 @@ def _fix_connectivity(X, connectivity, affinity):
6869
stacklevel=2,
6970
)
7071
# XXX: Can we do without completing the matrix?
71-
for i in range(n_connected_components):
72-
idx_i = np.where(labels == i)[0]
73-
Xi = X[idx_i]
74-
for j in range(i):
75-
idx_j = np.where(labels == j)[0]
76-
Xj = X[idx_j]
77-
if affinity == "precomputed":
78-
D = X[np.ix_(idx_i, idx_j)]
79-
else:
80-
D = pairwise_distances(Xi, Xj, metric=affinity)
81-
ii, jj = np.where(D == np.min(D))
82-
ii = ii[0]
83-
jj = jj[0]
84-
connectivity[idx_i[ii], idx_j[jj]] = True
85-
connectivity[idx_j[jj], idx_i[ii]] = True
72+
connectivity = _fix_connected_components(
73+
X=X,
74+
graph=connectivity,
75+
n_connected_components=n_connected_components,
76+
component_labels=labels,
77+
metric=affinity,
78+
mode="connectivity",
79+
)
8680

8781
return connectivity, n_connected_components
8882

@@ -661,7 +655,6 @@ def _single_linkage(*args, **kwargs):
661655
single=_single_linkage,
662656
)
663657

664-
665658
###############################################################################
666659
# Functions for cutting hierarchical clustering tree
667660

sklearn/manifold/_isomap.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22

33
# Author: Jake Vanderplas -- <vanderplas@astro.washington.edu>
44
# License: BSD 3 clause (C) 2011
5+
import warnings
56

67
import numpy as np
8+
import scipy
9+
from scipy.sparse.csgraph import shortest_path
10+
from scipy.sparse.csgraph import connected_components
11+
712
from ..base import BaseEstimator, TransformerMixin
813
from ..neighbors import NearestNeighbors, kneighbors_graph
914
from ..utils.validation import check_is_fitted
10-
from ..utils.graph import graph_shortest_path
1115
from ..decomposition import KernelPCA
1216
from ..preprocessing import KernelCenterer
17+
from ..utils.graph import _fix_connected_components
18+
from ..externals._packaging.version import parse as parse_version
1319

1420

1521
class Isomap(TransformerMixin, BaseEstimator):
@@ -186,9 +192,46 @@ def _fit_transform(self, X):
186192
n_jobs=self.n_jobs,
187193
)
188194

189-
self.dist_matrix_ = graph_shortest_path(
190-
kng, method=self.path_method, directed=False
191-
)
195+
# Compute the number of connected components, and connect the different
196+
# components to be able to compute a shortest path between all pairs
197+
# of samples in the graph.
198+
# Similar fix to cluster._agglomerative._fix_connectivity.
199+
n_connected_components, labels = connected_components(kng)
200+
if n_connected_components > 1:
201+
if self.metric == "precomputed":
202+
raise RuntimeError(
203+
"The number of connected components of the neighbors graph"
204+
f" is {n_connected_components} > 1. The graph cannot be "
205+
"completed with metric='precomputed', and Isomap cannot be"
206+
"fitted. Increase the number of neighbors to avoid this "
207+
"issue."
208+
)
209+
warnings.warn(
210+
"The number of connected components of the neighbors graph "
211+
f"is {n_connected_components} > 1. Completing the graph to fit"
212+
" Isomap might be slow. Increase the number of neighbors to "
213+
"avoid this issue.",
214+
stacklevel=2,
215+
)
216+
217+
# use array validated by NearestNeighbors
218+
kng = _fix_connected_components(
219+
X=self.nbrs_._fit_X,
220+
graph=kng,
221+
n_connected_components=n_connected_components,
222+
component_labels=labels,
223+
mode="distance",
224+
metric=self.nbrs_.effective_metric_,
225+
**self.nbrs_.effective_metric_params_,
226+
)
227+
228+
if parse_version(scipy.__version__) < parse_version("1.3.2"):
229+
# make identical samples have a nonzero distance, to account for
230+
# issues in old scipy Floyd-Warshall implementation.
231+
kng.data += 1e-15
232+
233+
self.dist_matrix_ = shortest_path(kng, method=self.path_method, directed=False)
234+
192235
G = self.dist_matrix_ ** 2
193236
G *= -0.5
194237

sklearn/manifold/tests/test_isomap.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_transform():
9494
X, y = datasets.make_s_curve(n_samples, random_state=0)
9595

9696
# Compute isomap embedding
97-
iso = manifold.Isomap(n_components=n_components, n_neighbors=2)
97+
iso = manifold.Isomap(n_components=n_components)
9898
X_iso = iso.fit_transform(X)
9999

100100
# Re-embed a noisy version of the points
@@ -190,6 +190,25 @@ def test_sparse_input():
190190
for eigen_solver in eigen_solvers:
191191
for path_method in path_methods:
192192
clf = manifold.Isomap(
193-
n_components=2, eigen_solver=eigen_solver, path_method=path_method
193+
n_components=2,
194+
eigen_solver=eigen_solver,
195+
path_method=path_method,
196+
n_neighbors=8,
194197
)
195198
clf.fit(X)
199+
200+
201+
def test_multiple_connected_components():
202+
# Test that a warning is raised when the graph has multiple components
203+
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
204+
with pytest.warns(UserWarning, match="number of connected components"):
205+
manifold.Isomap(n_neighbors=2).fit(X)
206+
207+
208+
def test_multiple_connected_components_metric_precomputed():
209+
# Test that an error is raised when the graph has multiple components
210+
# and when the metric is "precomputed".
211+
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
212+
X_graph = neighbors.kneighbors_graph(X, n_neighbors=2, mode="distance")
213+
with pytest.raises(RuntimeError, match="number of connected components"):
214+
manifold.Isomap(n_neighbors=1, metric="precomputed").fit(X_graph)

sklearn/utils/graph.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
# Jake Vanderplas <vanderplas@astro.washington.edu>
1111
# License: BSD 3 clause
1212

13+
import numpy as np
1314
from scipy import sparse
1415

15-
from .graph_shortest_path import graph_shortest_path # noqa
16+
from .deprecation import deprecated
17+
from ..metrics.pairwise import pairwise_distances
1618

1719

1820
###############################################################################
@@ -67,3 +69,130 @@ def single_source_shortest_path_length(graph, source, *, cutoff=None):
6769
break
6870
level += 1
6971
return seen # return all path lengths as dictionary
72+
73+
74+
@deprecated(
75+
"`graph_shortest_path` is deprecated in 1.0 (renaming of 0.25) and will "
76+
"be removed in 1.2. Use `scipy.sparse.csgraph.shortest_path` instead."
77+
)
78+
def graph_shortest_path(dist_matrix, directed=True, method="auto"):
79+
"""Shortest-path graph search on a positive directed or undirected graph.
80+
81+
Parameters
82+
----------
83+
dist_matrix : arraylike or sparse matrix, shape = (N,N)
84+
Array of positive distances.
85+
If vertex i is connected to vertex j, then dist_matrix[i,j] gives
86+
the distance between the vertices.
87+
If vertex i is not connected to vertex j, then dist_matrix[i,j] = 0
88+
89+
directed : boolean
90+
if True, then find the shortest path on a directed graph: only
91+
progress from a point to its neighbors, not the other way around.
92+
if False, then find the shortest path on an undirected graph: the
93+
algorithm can progress from a point to its neighbors and vice versa.
94+
95+
method : string ['auto'|'FW'|'D']
96+
method to u 10000 se. Options are
97+
'auto' : attempt to choose the best method for the current problem
98+
'FW' : Floyd-Warshall algorithm. O[N^3]
99+
'D' : Dijkstra's algorithm with Fibonacci stacks. O[(k+log(N))N^2]
100+
101+
Returns
102+
-------
103+
G : np.ndarray, float, shape = [N,N]
104+
G[i,j] gives the shortest distance from point i to point j
105+
along the graph.
106+
107+
Notes
108+
-----
109+
As currently implemented, Dijkstra's algorithm does not work for
110+
graphs with direction-dependent distances when directed == False.
111+
i.e., if dist_matrix[i,j] and dist_matrix[j,i] are not equal and
112+
both are nonzero, method='D' will not necessarily yield the correct
113+
result.
114+
Also, these routines have not been tested for graphs with negative
115+
distances. Negative distances can lead to infinite cycles that must
116+
be handled by specialized algorithms.
117+
"""
118+
return sparse.csgraph.shortest_path(dist_matrix, method=method, directed=directed)
119+
120+
121+
def _fix_connected_components(
122+
X,
123+
graph,
124+
n_connected_components,
125+
component_labels,
126+
mode="distance",
127+
metric="euclidean",
128+
**kwargs,
129+
):
130+
"""Add connections to sparse graph to connect unconnected components.
131+
132+
For each pair of unconnected components, compute all pairwise distances
133+
from one component to the other, and add a connection on the closest pair
134+
of samples. This is a hacky way to get a graph with a single connected
135+
component, which is necessary for example to compute a shortest path
136+
between all pairs of samples in the graph.
137+
138+
Parameters
139+
----------
140+
X : array of shape (n_samples, n_features) or (n_samples, n_samples)
141+
Features to compute the pairwise distances. If `metric =
10000
142+
"precomputed"`, X is the matrix of pairwise distances.
143+
144+
graph : sparse matrix of shape (n_samples, n_samples)
145+
Graph of connection between samples.
146+
147+
n_connected_components : int
148+
Number of connected components, as computed by
149+
`scipy.sparse.csgraph.connected_components`.
150+
151+
component_labels : array of shape (n_samples)
152+
Labels of connected components, as computed by
153+
`scipy.sparse.csgraph.connected_components`.
154+
155+
mode : {'connectivity', 'distance'}, default='distance'
156+
Type of graph matrix: 'connectivity' corresponds to the connectivity
157+
matrix with ones and zeros, and 'distance' corresponds to the distances
158+
between neighbors according to the given metric.
159+
160+
metric : str
161+
Metric used in `sklearn.metrics.pairwise.pairwise_distances`.
162+
163+
kwargs : kwargs
164+
Keyword arguments passed to
165+
`sklearn.metrics.pairwise.pairwise_distances`.
166+
167+
Returns
168+
-------
169+
graph : sparse matrix of shape (n_samples, n_samples)
170+
Graph of connection between samples, with a single connected component.
171+
"""
172+
173+
for i in range(n_connected_components):
174+
idx_i = np.flatnonzero(component_labels == i)
175+
Xi = X[idx_i]
176+
for j in range(i):
177+
idx_j = np.flatnonzero(component_labels == j)
178+
Xj = X[idx_j]
179+
180+
if metric == "precomputed":
181+
D = X[np.ix_(idx_i, idx_j)]
182+
else:
183+
D = pairwise_distances(Xi, Xj, metric=metric, **kwargs)
184+
185+
ii, jj = np.unravel_index(D.argmin(axis=None), D.shape)
186+
if mode == "connectivity":
187+
graph[idx_i[ii], idx_j[jj]] = 1
188+
graph[idx_j[jj], idx_i[ii]] = 1
189+
elif mode == "distance":
190+
graph[idx_i[ii], idx_j[jj]] = D[ii, jj]
191+
graph[idx_j[jj], idx_i[ii]] = D[ii, jj]
192+
else:
193+
raise ValueError(
194+
"Unknown mode=%r, should be one of ['connectivity', 'distance']."
195+
% mode
196+
)
197+
198+
return graph

0 commit comments

Comments
 (0)
0