10000 ENH Add dtype preservation to FeatureAgglomeration by IvanLauLinTiong · Pull Request #24346 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Add dtype preservation to FeatureAgglomeration #24346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
0e9ff56
added dtype preservations and relevent test for FeatureAgglomeration
IvanLauLinTiong Sep 3, 2022
c86bf03
updated changelog
IvanLauLinTiong Sep 3, 2022
0b6e5bc
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Sep 3, 2022
ef23c49
removed newline
IvanLauLinTiong Sep 3, 2022
4ea4488
nitpicks + don't duplicate test
jeremiedbb Sep 5, 2022
7667a18
Merge remote-tracking branch 'upstream/main' into pr/IvanLauLinTiong/…
jeremiedbb Sep 5, 2022
e1d7ba6
increase number of features for more stability
jeremiedbb Sep 5, 2022
1e53bd1
[all random seeds]
jeremiedbb Sep 5, 2022
d3fe8c4
[all random seeds]
jeremiedbb Sep 5, 2022
cace8c9
lint [all random seeds]
jeremiedbb Sep 5, 2022
d8288a1
Merge branch 'scikit-learn:main' into preserve_dtype_feature_agglomer…
IvanLauLinTiong Sep 5, 2022
384bbaa
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Sep 11, 2022
a841e40
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Sep 19, 2022
3e80bf0
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Sep 28, 2022
4617c98
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Oct 5, 2022
228db6c
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Oct 11, 2022
cf2066d
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Oct 13, 2022
928e6b9
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Oct 14, 2022
687b2cb
Merge branch 'main' into preserve_dtype_feature_agglomeration
IvanLauLinTiong Oct 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ Changelog
:class:`cluster.AgglomerativeClustering` and will be renamed to `metric` in v1.4.
:pr:`23470` by :user:`Meekail Zain <micky774>`.

- |Enhancement| :class:`cluster.FeatureAgglomeration` preserves dtype for
`numpy.float32`. :pr:`24346` by :user:`LinTiong Lau <IvanLauLinTiong>`.

- |Fix| :class:`cluster.KMeans` now supports readonly attributes when predicting.
:pr:`24258` by `Thomas Fan`_

Expand Down
7 changes: 6 additions & 1 deletion sklearn/cluster/_agglomerative.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,9 @@ def fit(self, X, y=None):
Returns the transformer.
"""
self._validate_params()
X = self._validate_data(X, ensure_min_features=2)
X = self._validate_data(
X, ensure_min_features=2, dtype=[np.float64, np.float32]
)
super()._fit(X.T)
self._n_features_out = self.n_clusters_
return self
Expand All @@ -1331,3 +1333,6 @@ def fit(self, X, y=None):
def fit_predict(self):
"""Fit and return the result of each sample's clustering assignment."""
raise AttributeError

def _more_tags(self):
return {"preserves_dtype": [np.float64, np.float32]}
7 changes: 4 additions & 3 deletions sklearn/cluster/_feature_agglomeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,21 @@ def transform(self, X):
"""
check_is_fitted(self)

X = self._validate_data(X, reset=False)
X = self._validate_data(X, reset=False, dtype=[np.float64, np.float32])
if self.pooling_func == np.mean and not issparse(X):
size = np.bincount(self.labels_)
n_samples = X.shape[0]
# a fast way to compute the mean of grouped features
nX = np.array(
[np.bincount(self.labels_, X[i, :]) / size for i in range(n_samples)]
[np.bincount(self.labels_, X[i, :]) / size for i in range(n_samples)],
dtype=X.dtype,
)
else:
nX = [
self.pooling_func(X[:, self.labels_ == l], axis=1)
for l in np.unique(self.labels_)
]
nX = np.array(nX).T
nX = np.array(nX, dtype=X.dtype).T
return nX

def inverse_transform(self, Xred):
Expand Down
17 changes: 16 additions & 1 deletion sklearn/cluster/tests/test_feature_agglomeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from numpy.testing import assert_array_equal
from sklearn.cluster import FeatureAgglomeration
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose, assert_array_almost_equal
from sklearn.datasets import make_blobs


Expand Down Expand Up @@ -53,3 +53,18 @@ def test_feature_agglomeration_feature_names_out():
assert_array_equal(
[f"featureagglomeration{i}" for i in range(n_clusters)], names_out
)


def test_feature_agglomeration_numerical_consistency(global_random_seed):
"""Ensure numerical consistency among np.float32 and np.float64"""
rng = np.random.RandomState(global_random_seed)
X_64, _ = make_blobs(n_features=12, random_state=rng)
X_32 = X_64.astype(np.float32)

agglo_32 = FeatureAgglomeration(n_clusters=3)
agglo_64 = FeatureAgglomeration(n_clusters=3)

X_trans_64 = agglo_64.fit_transform(X_64)
X_trans_32 = agglo_32.fit_transform(X_32)

assert_allclose(X_trans_32, X_trans_64)
0