[ROCm] Fix sort for non-standard bool (#147459) · pytorch/pytorch@703176e
pragupta authored and pytorchmergebot committed
[ROCm] Fix sort for non-standard bool (#147459)
When converting from uint8 to bool using the `view` op, we get a bool that stores 0 for false and a non-zero value for true. However, these non-standard bools have undefined behavior: only the last bit is read as 0 or 1 when deciding whether the value is false or true. In this fix, we convert bools to uint8, which maps false to 0 and any non-zero value to 1, essentially converting a non-standard bool into a standard bool and fixing the sort op for non-standard bools.

Fixes #139972

Pull Request resolved: #147459
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
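
As a rough illustration of the failure mode, here is a minimal sketch (device and values are assumptions, not the exact reproducer from issue #139972) that builds a non-standard bool tensor via `view` and sorts it:

    import torch

    # Reinterpreting uint8 storage as bool yields non-standard bools:
    # the "True" elements below are stored as 2 instead of 1.
    raw = torch.tensor([2, 0, 2, 0], dtype=torch.uint8, device="cuda")
    non_standard = raw.view(torch.bool)

    # A standard bool tensor with the same logical values, for comparison.
    standard = torch.tensor([True, False, True, False], device="cuda")

    # Both calls should place all False values before all True values;
    # before this fix, the ROCm sort path could get this wrong.
    print(torch.msort(non_standard))
    print(torch.msort(standard))
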
1 parent 690fc2c · commit 703176e

File tree: 2 files changed (+12, -7 lines)

aten/src/ATen/native/cuda/Sort.cpp (9 additions, 0 deletions)

@@ -65,6 +65,15 @@ void sort_cuda_kernel(
   const auto self_dtype = self.dtype();
   TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
     "Sort currently does not support complex dtypes on CUDA.");
+#if defined(USE_ROCM)
+  // ROCm has undefined behavior for non-standard bools. Here we are converting bool to uint8 which will
+  // convert false to 0 and true or any non-zero value to a 1. copy_ on const Tensors only changes the
+  // data in the tensor and not the metadata.
+  // That's why, tensor's dtype stays as bool. It just becomes a standard bool.
+  if (self_dtype == ScalarType::Bool) {
+    self.copy_(self.to(at::kByte));
+  }
+#endif
 
   // use inplace algorithm for smaller input sizes without stable=True
   if (should_use_small_sort(self, dim)) {
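
To see what the new normalization does from the Python side, here is a minimal sketch (assuming a ROCm/CUDA build; the commented storage values are illustrative): copying a bool tensor through uint8, as the self.copy_(self.to(at::kByte)) line above does, rewrites the underlying bytes to 0/1 without touching the tensor's metadata.

    import torch

    # A bool tensor whose storage holds non-standard values (2 and 3 for True).
    t = torch.tensor([2, 0, 3], dtype=torch.uint8, device="cuda").view(torch.bool)
    print(t.view(torch.uint8))  # underlying bytes: 2, 0, 3

    # In-place round trip through uint8, mirroring self.copy_(self.to(at::kByte)).
    t.copy_(t.to(torch.uint8))

    print(t.dtype)              # still torch.bool: copy_ leaves metadata unchanged
    print(t.view(torch.uint8))  # underlying bytes should now be only 0 and 1
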

torch/testing/_internal/common_methods_invocations.py (3 additions, 7 deletions)

@@ -18498,7 +18498,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            supports_fwgrad_bwgrad=True,
            skips=(
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
-                            dtypes=[torch.bool], device_type='cuda'),
+                            dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM),
            )),
     OpInfo('unique',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
@@ -19549,12 +19549,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           sample_inputs_func=sample_inputs_msort,
-           skips=(
-               # https://github.com/pytorch/pytorch/issues/139972
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
-                            dtypes=[torch.bool], device_type='cuda', active_if=TEST_WITH_ROCM),
-           )),
+           sample_inputs_func=sample_inputs_msort),
     OpInfo('movedim',
            aliases=('moveaxis',),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@@ -21380,6 +21375,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
             "test_non_standard_bool_values",
             dtypes=[torch.bool],
             device_type='cuda',
+            active_if=not TEST_WITH_ROCM
         ),
     ),
 ),
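
The active_if=not TEST_WITH_ROCM entries above now mark the expected failure only on non-ROCm CUDA builds. For context, below is a rough, self-contained sketch of the kind of check test_non_standard_bool_values performs for sort-like ops; the helper name and details are illustrative, not the actual TestCommon implementation.

    import torch

    def check_non_standard_bool_sort(device="cuda"):
        # Mix zero and non-standard "True" bytes, then reinterpret as bool.
        raw = torch.randint(2, 100, (16,), dtype=torch.uint8, device=device)
        raw[::2] = 0
        non_standard = raw.view(torch.bool)   # True stored as values > 1
        standard = raw.ne(0)                  # same logical values, standard bools

        result = torch.msort(non_standard)
        expected = torch.msort(standard)
        # Compare logically: viewing the result as uint8 and testing ne(0)
        # avoids trusting possibly non-standard bytes in the output.
        assert torch.equal(result.view(torch.uint8).ne(0), expected)

With the normalization added in Sort.cpp, such a check should pass on ROCm as well, which is why the ROCm-specific expected failures above could be dropped.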
