[MRG+2] faster way of computing means across each group (#10020) · scikit-learn/scikit-learn@555bf6b · GitHub
Commit 555bf6b

sergulaydore authored and agramfort committed
[MRG+2] faster way of computing means across each group (#10020)
1 parent abb43c1 commit 555bf6b

File tree

2 files changed: +57 -5 lines changed

sklearn/cluster/_feature_agglomeration.py

Lines changed: 14 additions & 5 deletions
@@ -10,10 +10,12 @@
 from ..base import TransformerMixin
 from ..utils import check_array
 from ..utils.validation import check_is_fitted
+from scipy.sparse import issparse

 ###############################################################################
 # Mixin class for feature agglomeration.

+
 class AgglomerationTransform(TransformerMixin):
     """
     A class for feature agglomeration via the transform interface
@@ -40,14 +42,21 @@ def transform(self, X):

         pooling_func = self.pooling_func
         X = check_array(X)
-        nX = []
         if len(self.labels_) != X.shape[1]:
             raise ValueError("X has a different number of features than "
                              "during fitting.")
-
-        for l in np.unique(self.labels_):
-            nX.append(pooling_func(X[:, self.labels_ == l], axis=1))
-        return np.array(nX).T
+        if pooling_func == np.mean and not issparse(X):
+            size = np.bincount(self.labels_)
+            n_samples = X.shape[0]
+            # a fast way to compute the mean of grouped features
+            nX = np.array([np.bincount(self.labels_, X[i, :]) / size
+                           for i in range(n_samples)])
+        else:
+            nX = []
+            for l in np.unique(self.labels_):
+                nX.append(pooling_func(X[:, self.labels_ == l], axis=1))
+            nX = np.array(nX).T
+        return nX

     def inverse_transform(self, Xred):
         """
sklearn/cluster/tests/test_feature_agglomeration.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+"""
+Tests for sklearn.cluster._feature_agglomeration
+"""
+# Authors: Sergul Aydore 2017
+import numpy as np
+from sklearn.cluster import FeatureAgglomeration
+from sklearn.utils.testing import assert_true
+from sklearn.utils.testing import assert_array_almost_equal
+
+
+def test_feature_agglomeration():
+    n_clusters = 1
+    X = np.array([0, 0, 1]).reshape(1, 3)  # (n_samples, n_features)
+
+    agglo_mean = FeatureAgglomeration(n_clusters=n_clusters,
+                                      pooling_func=np.mean)
+    agglo_median = FeatureAgglomeration(n_clusters=n_clusters,
+                                        pooling_func=np.median)
+    agglo_mean.fit(X)
+    agglo_median.fit(X)
+    assert_true(np.size(np.unique(agglo_mean.labels_)) == n_clusters)
+    assert_true(np.size(np.unique(agglo_median.labels_)) == n_clusters)
+    assert_true(np.size(agglo_mean.labels_) == X.shape[1])
+    assert_true(np.size(agglo_median.labels_) == X.shape[1])
+
+    # Test transform
+    Xt_mean = agglo_mean.transform(X)
+    Xt_median = agglo_median.transform(X)
+    assert_true(Xt_mean.shape[1] == n_clusters)
+    assert_true(Xt_median.shape[1] == n_clusters)
+    assert_true(Xt_mean == np.array([1 / 3.]))
+    assert_true(Xt_median == np.array([0.]))
+
+    # Test inverse transform
+    X_full_mean = agglo_mean.inverse_transform(Xt_mean)
+    X_full_median = agglo_median.inverse_transform(Xt_median)
+    assert_true(np.unique(X_full_mean[0]).size == n_clusters)
+    assert_true(np.unique(X_full_median[0]).size == n_clusters)
+
+    assert_array_almost_equal(agglo_mean.transform(X_full_mean),
+                              Xt_mean)
+    assert_array_almost_equal(agglo_median.transform(X_full_median),
+                              Xt_median)
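As a rough illustration of why the commit title calls the bincount path a "faster way of computing means across each group", one could time the two strategies side by side on random dense data. The shapes, label counts, and iteration numbers below are arbitrary choices for the sketch, not measurements from the pull request:

import numpy as np
from timeit import timeit

rng = np.random.RandomState(0)
n_samples, n_features, n_clusters = 200, 5000, 50
X = rng.rand(n_samples, n_features)
labels = rng.randint(0, n_clusters, n_features)
size = np.bincount(labels)

def mean_bincount():
    # vectorized per-sample grouped mean, as in the new transform() fast path
    return np.array([np.bincount(labels, X[i, :]) / size
                     for i in range(n_samples)])

def mean_loop():
    # previous behaviour: one boolean mask per unique label
    return np.array([X[:, labels == l].mean(axis=1)
                     for l in np.unique(labels)]).T

assert np.allclose(mean_bincount(), mean_loop())
print("bincount path:", timeit(mean_bincount, number=10))
print("loop path:    ", timeit(mean_loop, number=10))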

0 commit comments
