-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[WIP] Sparse output KNN #3350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Sparse output KNN #3350
Changes from all commits
40cb849
7ffd392
a41a463
e3afae1
60c04b2
7057858
2ed6bda
3b1ac31
768a613
c14c771
b07035e
6870b8d
19f8434
7c71da2
7c1fcdd
2af4597
48efa6a
c510547
d18b579
3be5bde
99c5a31
e93896c
bb1c0ae
043ed2e
844933e
e82770e
72f5cdd
e61368f
cb04d96
4d5d100
d8deae5
ccae2ae
6b16ee7
2ba9f3f
e640807
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,15 +7,20 @@ | |
# Multi-output support by Arnaud Joly <a.joly@ulg.ac.be> | ||
# | ||
# License: BSD 3 clause (C) INRIA, University of Amsterdam | ||
|
||
import array | ||
import numpy as np | ||
import scipy.sparse as sp | ||
|
||
from scipy import stats | ||
from ..utils.extmath import weighted_mode | ||
|
||
from .base import \ | ||
_check_weights, _get_weights, \ | ||
NeighborsBase, KNeighborsMixin,\ | ||
RadiusNeighborsMixin, SupervisedIntegerMixin | ||
from .base import _check_weights | ||
from .base import _get_weights | ||
from .base import NeighborsBase | ||
from .base import KNeighborsMixin | ||
from .base import RadiusNeighborsMixin | ||
from .base import SupervisedIntegerMixin | ||
|
||
from ..base import ClassifierMixin | ||
from ..utils import check_array | ||
|
||
|
@@ -146,18 +151,42 @@ def predict(self, X): | |
n_samples = X.shape[0] | ||
weights = _get_weights(neigh_dist, self.weights) | ||
|
||
y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) | ||
for k, classes_k in enumerate(classes_): | ||
if weights is None: | ||
mode, _ = stats.mode(_y[neigh_ind, k], axis=1) | ||
else: | ||
mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) | ||
if not self.sparse_target_input_: | ||
y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) | ||
for k, classes_k in enumerate(classes_): | ||
if weights is None: | ||
mode, _ = stats.mode(_y[neigh_ind, k], axis=1) | ||
else: | ||
mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) | ||
|
||
mode = np.asarray(mode.ravel(), dtype=np.intp) | ||
y_pred[:, k] = classes_k.take(mode) | ||
mode = np.asarray(mode.ravel(), dtype=np.intp) | ||
y_pred[:, k] = classes_k.take(mode) | ||
|
||
if not self.outputs_2d_: | ||
y_pred = y_pred.ravel() | ||
if not self.outputs_2d_: | ||
y_pred = y_pred.ravel() | ||
|
||
else: | ||
|
||
data = [] | ||
indices = array.array('i') | ||
indptr = array.array('i', [0]) | ||
|
||
for k, classes_k in enumerate(classes_): | ||
neigh_lbls_k = _y.getcol(k).toarray().ravel()[neigh_ind] | ||
neigh_lbls_k = classes_k[neigh_lbls_k] | ||
|
||
if weights is None: | ||
mode, _ = stats.mode(neigh_lbls_k, axis=1) | ||
else: | ||
mode, _ = weighted_mode(neigh_lbls_k, weights, axis=1) | ||
|
||
data.extend(mode[mode != 0]) | ||
indices.extend(np.where(mode != 0)[0]) | ||
indptr.append(len(indices)) | ||
|
||
y_pred = sp.csc_matrix((data, indices, indptr), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You seem to be missing a […]. I am also not sure that you should be requiring the data to be […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have corrected this; the test was changed so that it fails if the sampling from classes is not used. The dtype is now also maintained when predicting, and there is an assert for this in the test. |
||
(n_samples, n_outputs), | ||
dtype=classes_[0].dtype) | ||
|
||
return y_pred | ||
|
||
|
@@ -182,6 +211,10 @@ def predict_proba(self, X): | |
|
||
classes_ = self.classes_ | ||
_y = self._y | ||
|
||
if self.sparse_target_input_: | ||
_y = _y.toarray() | ||
|
||
if not self.outputs_2d_: | ||
_y = self._y.reshape((-1, 1)) | ||
classes_ = [self.classes_] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic should probably be moved to `LabelEncoding` at some point; it currently does not handle multioutput, nor sparse (but the latter had only been used for binary targets until now). The sparse implementation is not explicitly tested, and some of its conditions are only being tested because the random number generation happens to produce entirely dense and non-dense columns.