8000 TST activate common tests for TSNE (#25374) · scikit-learn/scikit-learn@ab7e3d1 · GitHub
[go: up one dir, main page]

Skip to content

Commit ab7e3d1

Browse files
authored
TST activate common tests for TSNE (#25374)
1 parent 263b428 commit ab7e3d1

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

sklearn/manifold/_t_sne.py

< 8000 div class="d-flex flex-row flex-justify-end flex-1 flex-order-1 flex-sm-order-2 flex-items-center">
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from scipy.sparse import csr_matrix, issparse
1818
from numbers import Integral, Real
1919
from ..neighbors import NearestNeighbors
20-
from ..base import BaseEstimator
20+
from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
2121
from ..utils import check_random_state
2222
from ..utils._openmp_helpers import _openmp_effective_n_threads
2323
from ..utils.validation import check_non_negative
@@ -537,7 +537,7 @@ def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
537537
return t
538538

539539

540-
class TSNE(BaseEstimator):
540+
class TSNE(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
541541
"""T-distributed Stochastic Neighbor Embedding.
542542
543543
t-SNE [1] is a tool to visualize high-dimensional data. It converts
@@ -1145,5 +1145,10 @@ def fit(self, X, y=None):
11451145
self.fit_transform(X)
11461146
return self
11471147

1148+
@property
1149+
def _n_features_out(self):
1150+
"""Number of transformed output features."""
1151+
return self.embedding_.shape[1]
1152+
11481153
def _more_tags(self):
11491154
return {"pairwise": self.metric == "precomputed"}

sklearn/utils/estimator_checks.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4234,9 +4234,14 @@ def fit_then_transform(est):
42344234
def fit_transform(est):
42354235
return est.fit_transform(X, y)
42364236

4237-
transform_methods = [fit_then_transform, fit_transform]
4238-
for transform_method in transform_methods:
4237+
transform_methods = {
4238+
"transform": fit_then_transform,
4239+
"fit_transform": fit_transform,
4240+
}
4241+
for name, transform_method in transform_methods.items():
42394242
transformer = clone(transformer)
4243+
if not hasattr(transformer, name):
4244+
continue
42404245
X_trans_no_setting = transform_method(transformer)
42414246

42424247
# Auto wrapping only wraps the first array
@@ -4269,29 +4274,31 @@ def _output_from_fit_transform(transformer, name, X, df, y):
42694274
("fit.transform/array/df", X, df),
42704275
("fit.transform/array/array", X, X),
42714276
]
4272-
for (
4273-
case,
4274-
data_fit,
4275-
data_transform,
4276-
) in cases:
4277-
transformer.fit(data_fit, y)
4278-
if name in CROSS_DECOMPOSITION:
4279-
X_trans, _ = transformer.transform(data_transform, y)
4280-
else:
4281-
X_trans = transformer.transform(data_transform)
4282-
outputs[case] = (X_trans, transformer.get_feature_names_out())
4277+
if all(hasattr(transformer, meth) for meth in ["fit", "transform"]):
4278+
for (
4279+
case,
4280+
data_fit,
4281+
data_transform,
4282+
) in cases:
4283+
transformer.fit(data_fit, y)
4284+
if name in CROSS_DECOMPOSITION:
4285+
X_trans, _ = transformer.transform(data_transform, y)
4286+
else:
4287+
X_trans = transformer.transform(data_transform)
4288+
outputs[case] = (X_trans, transformer.get_feature_names_out())
42834289

42844290
# fit_transform case:
42854291
cases = [
42864292
("fit_transform/df", df),
42874293
("fit_transform/array", X),
42884294
]
4289-
for case, data in cases:
4290-
if name in CROSS_DECOMPOSITION:
4291-
X_trans, _ = transformer.fit_transform(data, y)
4292-
else:
4293-
X_trans = transformer.fit_transform(data, y)
4294-
outputs[case] = (X_trans, transformer.get_feature_names_out())
4295+
if hasattr(transformer, "fit_transform"):
4296+
for case, data in cases:
4297+
if name in CROSS_DECOMPOSITION:
4298+
X_trans, _ = transformer.fit_transform(data, y)
4299+
else:
4300+
X_trans = transformer.fit_transform(data, y)
4301+
outputs[case] = (X_trans, transformer.get_feature_names_out())
42954302

42964303
return outputs
42974304

0 commit comments

Comments
 (0)
0