8000 Recover non-standard bool test for msort (#139870) · pytorch/pytorch@565a794 · GitHub
[go: up one dir, main page]

Skip to content

Commit 565a794

Browse files
xw285cornellpytorchmergebot
authored andcommitted
Recover non-standard bool test for msort (#139870)
Summary: I was looking into why the non-standard bool value will fail for msort - it makes sense for argsort and sort to fail, because we're randomly generating uint8 so the order will be different (and thus the indices will be different). But msort should work. After some digging, it's interesting that even though scalar_t is bool, when the actual value is a uint8_t, the comparison will treat them as signed. I tried lhs=255 and rhs=0: lhs < rhs is equivalent to -1 < 0 which is true (but it's supposed to be False) Therefore we add an explicit type cast. Test Plan: Remove the test skip Differential Revision: D65472170 Pull Request resolved: #139870 Approved by: https://github.com/Skylion007, https://github.com/davidberard98
1 parent 2f3a5a1 commit 565a794

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

aten/src/ATen/native/cuda/SortingCommon.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@ inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
4646
template <typename scalar_t, bool handleNaN = false>
4747
struct GTOp {
4848
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
49-
return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs);
49+
return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) ||
50+
(static_cast<scalar_t>(lhs) > static_cast<scalar_t>(rhs));
5051
}
5152
};
5253

5354
template <typename scalar_t, bool handleNaN = false>
5455
struct LTOp {
5556
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
56-
return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs);
57+
return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) ||
58+
(static_cast<scalar_t>(lhs) < static_cast<scalar_t>(rhs));
5759
}
5860
};
5961

torch/testing/_internal/common_methods_invocations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19517,8 +19517,9 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1951719517
supports_fwgrad_bwgrad=True,
1951819518
sample_inputs_func=sample_inputs_msort,
1951919519
skips=(
19520+
# https://github.com/pytorch/pytorch/issues/139972
1952019521
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
19521-
dtypes=[torch.bool], device_type='cuda'),
19522+
dtypes=[torch.bool], device_type='cuda', active_if=TEST_WITH_ROCM),
1952219523
)),
1952319524
OpInfo('movedim',
1952419525
aliases=('moveaxis',),

0 commit comments

Comments
 (0)
0