diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 8bdc4bf0..7c64418c 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -180,6 +180,10 @@ def fit(self, X, y): G, objective, total_active = self._loss_grad(X, L, dfG, k, reg, target_neighbors, label_inds) + if G is None: + # TODO: raise a warning + self.n_iter_ = 0 + return self it = 1 # we already made one iteration @@ -244,6 +248,8 @@ def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds): Ni.argmax(axis=1)[:, None], 1) impostors = self._find_impostors(furthest_neighbors.ravel(), X, label_inds, L) + if not impostors: + return None, 0, 0 g0 = _inplace_paired_L2(*Lx[impostors])