8000 Fix test to pass on GPU · IBMZ-Linux-OSS-Python/tensorflow@99b1aac · GitHub
[go: up one dir, main page]

Skip to content

Commit 99b1aac

Browse files
Fix test to pass on GPU
PiperOrigin-RevId: 673138979
1 parent 3b53593 commit 99b1aac

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow/python/kernel_tests/signal/fft_ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def testFftLength_rfftn(self, dims, size, np_rtype):
627627
else:
628628
fft_length = (size * 2, size, size * 2)
629629
axes = (-3, -2, -1)
630-
self._CompareBackward_fftn(
630+
self._CompareForward_fftn(
631631
r2c.astype(np_rtype),
632632
fft_length=fft_length,
633633
axes=axes,
@@ -639,7 +639,7 @@ def testFftLength_rfftn(self, dims, size, np_rtype):
639639
c2r = self._generate_valid_irfft_input(
640640
c2r, np_ctype, r2c, np_rtype, 2, fft_length
641641
)
642-
self._CompareForward_fftn(c2r, fft_length, axes, rtol=tol)
642+
self._CompareBackward_fftn(c2r, fft_length, axes, rtol=tol)
643643

644644
@parameterized.parameters(
645645
itertools.product(range(1, 4), (5, 6), (np.float32, np.float64))

0 commit comments

Comments
 (0)
0