ENH Adds get_feature_names_out to neighbors module (#22212) · scikit-learn/scikit-learn@330881a · GitHub
[go: up one dir, main page]

Skip to content

Commit 330881a

Browse files
Micky774 and ogrisel authored
ENH Adds get_feature_names_out to neighbors module (#22212)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 4327784 commit 330881a

File tree

6 files changed

+65
-8
lines changed

6 files changed

+65
-8
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,11 @@ Changelog
439439
ndarray with `np.nan` when passed a `Float32` or `Float64` pandas extension
440440
array with `pd.NA`. :pr:`21278` by `Thomas Fan`_.
441441

442+
- |Enhancement| Adds :term:`get_feature_names_out` to
443+
:class:`neighbors.RadiusNeighborsTransformer`, :class:`neighbors.KNeighborsTransformer`
444+
and :class:`neighbors.NeighborhoodComponentsAnalysis`. :pr:`22212` by
445+
:user:`Meekail Zain <micky774>`.
446+
442447
- |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit`
443448
instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva <DessyVV>` and
444449
:user:`Lucy Jimenez <LucyJimenez>`.

sklearn/neighbors/_graph.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ._base import KNeighborsMixin, RadiusNeighborsMixin
88
from ._base import NeighborsBase
99
from ._unsupervised import NearestNeighbors
10-
from ..base import TransformerMixin
10+
from ..base import TransformerMixin, _ClassNamePrefixFeaturesOutMixin
1111
from ..utils.validation import check_is_fitted
1212

1313

@@ -223,7 +223,9 @@ def radius_neighbors_graph(
223223
return X.radius_neighbors_graph(query, radius, mode)
224224

225225

226-
class KNeighborsTransformer(KNeighborsMixin, TransformerMixin, NeighborsBase):
226+
class KNeighborsTransformer(
227+
_ClassNamePrefixFeaturesOutMixin, KNeighborsMixin, TransformerMixin, NeighborsBase
228+
):
227229
"""Transform X into a (weighted) graph of k nearest neighbors.
228230
229231
The transformed data is a sparse graph as returned by kneighbors_graph.
@@ -389,7 +391,9 @@ def fit(self, X, y=None):
389391
self : KNeighborsTransformer
390392
The fitted k-nearest neighbors transformer.
391393
"""
392-
return self._fit(X)
394+
self._fit(X)
395+
self._n_features_out = self.n_samples_fit_
396+
return self
393397

394398
def transform(self, X):
395399
"""Compute the (weighted) graph of Neighbors for points in X.
@@ -445,7 +449,12 @@ def _more_tags(self):
445449
}
446450

447451

448-
class RadiusNeighborsTransformer(RadiusNeighborsMixin, TransformerMixin, NeighborsBase):
452+
class RadiusNeighborsTransformer(
453+
_ClassNamePrefixFeaturesOutMixin,
454+
RadiusNeighborsMixin,
455+
TransformerMixin,
456+
NeighborsBase,
457+
):
449458
"""Transform X into a (weighted) graph of neighbors nearer than a radius.
450459
451460
The transformed data is a sparse graph as returned by
@@ -614,7 +623,9 @@ def fit(self, X, y=None):
614623
self : RadiusNeighborsTransformer
615624
The fitted radius neighbors transformer.
616625
"""
617-
return self._fit(X)
626+
self._fit(X)
627+
self._n_features_out = self.n_samples_fit_
628+
return self
618629

619630
def transform(self, X):
620631
"""Compute the (weighted) graph of Neighbors for points in X.

sklearn/neighbors/_nca.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy.optimize import minimize
1616
from ..utils.extmath import softmax
1717
from ..metrics import pairwise_distances
18-
from ..base import BaseEstimator, TransformerMixin
18+
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
1919
from ..preprocessing import LabelEncoder
2020
from ..decomposition import PCA
2121
from ..utils.multiclass import check_classification_targets
@@ -24,7 +24,9 @@
2424
from ..exceptions import ConvergenceWarning
2525

2626

27-
class NeighborhoodComponentsAnalysis(TransformerMixin, BaseEstimator):
27+
class NeighborhoodComponentsAnalysis(
28+
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator
29+
):
2830
"""Neighborhood Components Analysis.
2931
3032
Neighborhood Component Analysis (NCA) is a machine learning algorithm for
@@ -249,6 +251,7 @@ def fit(self, X, y):
249251

250252
# Reshape the solution found by the optimizer
251253
self.components_ = opt_result.x.reshape(-1, X.shape[1])
254+
self._n_features_out = self.components_.shape[1]
252255

253256
# Stop timer
254257
t_train = time.time() - t_train

sklearn/neighbors/tests/test_graph.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
import pytest
23

34
from sklearn.metrics import euclidean_distances
45
from sklearn.neighbors import KNeighborsTransformer, RadiusNeighborsTransformer
56
from sklearn.neighbors._base import _is_sorted_by_data
7+
from sklearn.utils._testing import assert_array_equal
68

79

810
def test_transformer_result():
@@ -77,3 +79,23 @@ def test_explicit_diagonal():
7779
# Using transform on new data should not always have zero diagonal
7880
X2t = nnt.transform(X2)
7981
assert not _has_explicit_diagonal(X2t)
82+
83+
84+
@pytest.mark.parametrize("Klass", [KNeighborsTransformer, RadiusNeighborsTransformer])
85+
def test_graph_feature_names_out(Klass):
86+
"""Check `get_feature_names_out` for transformers defined in `_graph.py`."""
87+
88+
n_samples_fit = 20
89+
n_features = 10
90+
rng = np.random.RandomState(42)
91+
X = rng.randn(n_samples_fit, n_features)
92+
93+
est = Klass().fit(X)
94+
names_out = est.get_feature_names_out()
95+
96+
class_name_lower = Klass.__name__.lower()
97+
expected_names_out = np.array(
98+
[f"{class_name_lower}{i}" for i in range(est.n_samples_fit_)],
99+
dtype=object,
100+
)
101+
assert_array_equal(names_out, expected_names_out)

sklearn/neighbors/tests/test_nca.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,3 +554,20 @@ def test_parameters_valid_types(param, value):
554554
y = iris_target
555555

556556
nca.fit(X, y)
557+
558+
559+
def test_nca_feature_names_out():
560+
"""Check `get_feature_names_out` for `NeighborhoodComponentsAnalysis`."""
561+
562+
X = iris_data
563+
y = iris_target
564+
565+
est = NeighborhoodComponentsAnalysis().fit(X, y)
566+
names_out = est.get_feature_names_out()
567+
568+
class_name_lower = est.__class__.__name__.lower()
569+
expected_names_out = np.array(
570+
[f"{class_name_lower}{i}" for i in range(est.components_.shape[1])],
571+
dtype=object,
572+
)
573+
assert_array_equal(names_out, expected_names_out)

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ def test_pandas_column_name_consistency(estimator):
386386
"kernel_approximation",
387387
"preprocessing",
388388
"manifold",
389-
"neighbors",
390389
"neural_network",
391390
]
392391

0 commit comments

Comments
 (0)
0