8000 [MRG+1] BUG Fix the shrinkage implementation in NearestCentroid (#9219) · musically-ut/scikit-learn@90607f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 90607f1

Browse files
qinhanmin2014amueller
authored andcommitted
[MRG+1] BUG Fix the shrinkage implementation in NearestCentroid (scikit-learn#9219)
* fix the shrinkage implementation * update function name * update what's new * update what's new * spelling * confict fix * conflict fix
1 parent 8811d59 commit 90607f1

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ random sampling procedures.
6363
- :class:`linear_model.LassoLars` (bug fix)
6464
- :class:`linear_model.LassoLarsIC` (bug fix)
6565
- :class:`manifold.TSNE` (bug fix)
66+
- :class:`neighbors.NearestCentroid` (bug fix)
6667
- :class:`semi_supervised.LabelSpreading` (bug fix)
6768
- :class:`semi_supervised.LabelPropagation` (bug fix)
6869
- tree based models where ``min_weight_fraction_leaf`` is used (enhancement)
@@ -536,6 +537,9 @@ Decomposition, manifold learning and clustering
536537
- Fix bug where :mod:`mixture` ``sample`` methods did not return as many
537538
samples as requested. :issue:`7702` by :user:`Levi John Wolf <ljwolf>`.
538539

540+
- Fixed the shrinkage implementation in :class:`neighbors.NearestCentroid`.
541+
:issue:`9219` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
542+
539543
Preprocessing and feature selection
540544

541545
- For sparse matrices, :func:`preprocessing.normalize` with ``return_norm=True``

sklearn/neighbors/nearest_centroid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def fit(self, X, y):
147147
dataset_centroid_ = np.mean(X, axis=0)
148148

149149
# m parameter for determining deviation
150-
m = np.sqrt((1. / nk) + (1. / n_samples))
150+
m = np.sqrt((1. / nk) - (1. / n_samples))
151151
# Calculate deviation using the standard deviation of centroids.
152152
variance = (X - self.centroids_[y_ind]) ** 2
153153
variance = variance.sum(axis=0)

sklearn/neighbors/tests/test_nearest_centroid.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def test_pickle():
9797
" after pickling (classification).")
9898

9999

100+
def test_shrinkage_correct():
101+
# Ensure that the shrinking is correct.
102+
# The expected result is calculated by R (pamr),
103+
# which is implemented by the author of the original paper.
104+
# (One need to modify the code to output the new centroid in pamr.predict)
105+
106+
X = np.array([[0, 1], [1, 0], [1, 1], [2, 0], [6, 8]])
107+
y = np.array([1, 1, 2, 2, 2])
108+
clf = NearestCentroid(shrink_threshold=0.1)
109+
clf.fit(X, y)
110+
expected_result = np.array([[0.7787310, 0.8545292], [2.814179, 2.763647]])
111+
np.testing.assert_array_almost_equal(clf.centroids_, expected_result)
112+
113+
100114
def test_shrinkage_threshold_decoded_y():
101115
clf = NearestCentroid(shrink_threshold=0.01)
102116
y_ind = np.asarray(y)

0 commit comments

Comments
 (0)
0