🐛 Bug
TestReductionsCPU.test_nansum_out_dtype_cpu is flaky, but its flakiness is currently hidden.
Currently, this test iterates over combinations of two supported dtypes (the input dtype and the output dtype):
pytorch/test/test_reductions.py, lines 1112 to 1114 at 1aa14fc:

def test_nansum_out_dtype(self, device):
    dtypes = list(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))
    for inp_dtype, out_dtype in combinations(dtypes, 2):
Printing those combinations yields a list whose first few elements are, in order:
[[torch.uint8, torch.int8], [torch.uint8, torch.int16], [torch.uint8, torch.int32],
[torch.uint8, torch.int64], [torch.uint8, torch.float64], [torch.uint8, torch.float32], [torch.uint8, torch.float16]]
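For reference, the combinations and their order can be inspected with a short snippet like the one below, using the same torch.testing helpers the test calls (these helpers may be deprecated in newer PyTorch releases):

from itertools import combinations
import torch

dtypes = list(torch.testing.get_all_int_dtypes()
              + torch.testing.get_all_fp_dtypes(include_bfloat16=False))
for inp_dtype, out_dtype in combinations(dtypes, 2):
    print([inp_dtype, out_dtype])  # the first seven pairs all have torch.uint8 as inp_dtype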
However, the test fails if [torch.uint8, torch.float16] is swapped with [torch.uint8, torch.float32], i.e. something like:
[[torch.uint8, torch.int8], [torch.uint8, torch.int16], [torch.uint8, torch.int32],
[torch.uint8, torch.int64], [torch.uint8, torch.float64], [torch.uint8, torch.float16], [torch.uint8, torch.float32]]
Steps to reproduce the behavior:
- Replace lines 1113 & 1114 of test_nansum_out_dtype in test_reductions.py with:
for inp_dtype, out_dtype in [[torch.uint8, torch.int8], [torch.uint8, torch.int16], [torch.uint8, torch.int32],
[torch.uint8, torch.int64], [torch.uint8, torch.float64], [torch.uint8, torch.float16], [torch.uint8, torch.float32]]:
- Run the test with the -v option (an example invocation is given after the traceback below). The test fails with the error:
Traceback (most recent call last):
File "/pytorch/torch/testing/_internal/common_device_type.py", line 292, in instantiated_test
result = test_fn(self, *args)
File "test_reductions.py", line 1121, in test_nansum_out_dtype
self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
File "/pytorch/torch/testing/_internal/common_utils.py", line 1145, in compare_with_numpy
self.assertEqual(np_result, torch_result, **kwargs)
File "/pytorch/torch/testing/_internal/common_utils.py", line 1272, in assertEqual
self.assertEqual(x, y.item(), atol=atol, rtol=rtol, msg=msg,
File "/pytorch/torch/testing/_internal/common_utils.py", line 1403, in assertEqual
super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
AssertionError: False is not true : Scalars failed to compare as equal! Comparing 4448.0 and 4444.0 gives a difference of 4.0, but the allowed difference with rtol=1.3e-06 and atol=1e-05 is only 0.0057872!
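For reference, the single test can be run from the test/ directory with something like python test_reductions.py TestReductionsCPU.test_nansum_out_dtype_cpu -v; the exact invocation may differ with the local setup.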
- The above issue can also be reproduced with the following snippet of Python code:
import torch
import numpy
x = torch.tensor([[66, 63, 21, 84, 86, 90, 86, 21],
[81, 68, 22, 39, 43, 38, 24, 63],
[96, 98, 33, 85, 90, 87, 52, 99],
[56, 21, 95, 59, 85, 51, 16, 37],
[74, 55, 95, 97, 45, 82, 71, 33],
[88, 90, 91, 68, 31, 22, 53, 30],
[68, 22, 73, 82, 36, 56, 96, 20],
[71, 66, 82, 17, 97, 35, 20, 43],
[31, 81, 87, 90, 60, 49, 96, 91]], dtype=torch.uint8)
y = torch.nansum(x, dtype=torch.float16)
z = numpy.nansum(x.detach().cpu().numpy(), dtype=numpy.float16)
# y is 4444.0, but z is 4450.0
- Changing the order further can make this test pass, as can removing some entries from the list.
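The size of the discrepancy is consistent with float16 rounding: the exact integer sum of the tensor in the snippet above is 4449, and in the range [4096, 8192) adjacent float16 values are 4 apart, so 4449 is not exactly representable and the 4.0 difference reported by the assertion is exactly one float16 ULP there. A minimal illustration (my own reasoning about float16 precision, not taken from the PyTorch or NumPy implementations):

import numpy as np

# The exact integer sum of the example tensor is 4449.
# float16 has a 10-bit mantissa, so in [4096, 8192) consecutive representable
# values are 4 apart and 4449 itself cannot be represented exactly.
print(np.float16(4449))                    # rounds to 4448.0
print(np.float16(4448), np.float16(4452))  # the two nearest representable values
print(np.finfo(np.float16).nmant)          # 10 mantissa bits

Which neighbour an accumulation lands on depends on the order and precision of the intermediate additions, which is consistent with torch and numpy (and CPU vs. CUDA, see Additional context below) producing results that differ by one or two ULPs on the same data.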
Expected behavior
This test should pass regardless of the order of the [inp_dtype, out_dtype] pairs.
Environment
PyTorch built from source at the current master branch.
PyTorch version: 1.10.0a0+git1aa14fc
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-67-generic-x86_64-with-glibc2.29
Versions of relevant libraries:
[pip3] numpy==1.20.3
Additional context
On CUDA, the snippet in step 3 evaluates y to 4448.0 with PyTorch 1.8.1+cu101 on Google Colab, as opposed to 4444.0 from a CPU build of the current master branch (as well as from the v1.8.1 CPU build on Google Colab). numpy evaluates it to 4450.0.