8000 TST fit_transform(X)==fit(X).transform(X) · InterferencePattern/scikit-learn@91c9fab · GitHub
[go: up one dir, main page]

Skip to content

Commit 91c9fab

Browse files
hrishikeshiolarsmans
authored andcommitted
TST fit_transform(X)==fit(X).transform(X)
1 parent 4a9ecf0 commit 91c9fab

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sklearn/tests/test_common.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,25 @@ def test_transformers():
216216
if hasattr(trans, 'transform'):
217217
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
218218
X_pred2 = trans.transform(X, y_)
219+
X_pred3 = trans.fit_transform(X, y=y_)
219220
else:
220221
X_pred2 = trans.transform(X)
222+
X_pred3 = trans.fit_transform(X, y=y_)
221223
if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
222-
for x_pred, x_pred2 in zip(X_pred, X_pred2):
224+
for x_pred, x_pred2, x_pred3 in zip(X_pred, X_pred2, X_pred3):
223225
assert_array_almost_equal(
224226
x_pred, x_pred2, 2,
225227
"fit_transform not correct in %s" % Trans)
228+
assert_array_almost_equal(
229+
x_pred3, x_pred2, 2,
230+
"fit_transform not correct in %s" % Trans)
226231
else:
227232
assert_array_almost_equal(
228233
X_pred, X_pred2, 2,
229234
"fit_transform not correct in %s" % Trans)
235+
assert_array_almost_equal(
236+
X_pred3, X_pred2, 2,
237+
"fit_transform not correct in %s" % Trans)
230238

231239
# raises error on malformed input for transform
232240
assert_raises(ValueError, trans.transform, X.T)
@@ -530,7 +538,7 @@ def test_classifiers_classes():
530538
y = 2 * y + 1
531539
classes = np.unique(y)
532540
# TODO: make work with next line :)
533-
#y = y.astype(np.str)
541+
# y = y.astype(np.str)
534542
for name, Clf in classifiers:
535543
if name in dont_test:
536544
continue
@@ -645,7 +653,7 @@ def test_configure():
645653
with warnings.catch_warnings():
646654
# The configuration spits out warnings when not finding
647655
# Blas/Atlas development headers
648-
warnings.simplefilter('ignore', UserWarning)
656+
warnings.simplefilter('ignore', UserWarning)
649657
execfile('setup.py', dict(__name__='__main__'))
650658
finally:
651659
sys.argv = old_argv

0 commit comments

Comments
 (0)
0