Revert D34152115: [pytorch][PR] [ROCm] Enable sort operator BF16 support · pytorch/pytorch@80f2346 · GitHub

Commit 80f2346

malfet authored and pytorchmergebot committed
Revert D34152115: [pytorch][PR] [ROCm] Enable sort operator BF16 support
Test Plan: revert-hammer

Differential Revision: D34152115 (aa44480)

Original commit changeset: 53841c91976b
Original Phabricator Diff: D34152115 (aa44480)

fbshipit-source-id: c9b5cc06198032af73cd6390466de2c62576a1e1
(cherry picked from commit eb72533)
1 parent dc169d5 commit 80f2346

5 files changed: +22 −31 lines changed

aten/src/ATen/cuda/cub.cu

Lines changed: 3 additions & 0 deletions
@@ -57,7 +57,10 @@ AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
 
 AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
 
+// BFloat16 is not supported by ROCm's radix sort
+#if !AT_ROCM_ENABLED()
 AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
+#endif
 
 } // namespace detail
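For readers unfamiliar with the idiom, a minimal host-only sketch of the guarded explicit-instantiation pattern this hunk restores (all names here are stand-ins, not the real ATen macros): the translation unit instantiates the sort template once per supported key type, and a preprocessor guard simply omits the instantiation the ROCm backend cannot compile.

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for the templated cub sort wrapper declared in cub.cuh.
template <typename key_t>
void sort_keys(std::vector<key_t>& keys) {
  std::sort(keys.begin(), keys.end());
}

// Explicit instantiations, as cub.cu does via AT_INSTANTIATE_SORT_PAIRS.
template void sort_keys<int>(std::vector<int>&);
template void sort_keys<float>(std::vector<float>&);

// Hypothetical guard mirroring `#if !AT_ROCM_ENABLED()`: the instantiation
// simply does not exist in builds where the backend cannot provide it.
#if !defined(FAKE_ROCM)
template void sort_keys<double>(std::vector<double>&);
#endif

int main() {
  std::vector<int> v{3, 1, 2};
  sort_keys(v);
  std::printf("%d %d %d\n", v[0], v[1], v[2]);  // prints: 1 2 3
  return 0;
}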

aten/src/ATen/cuda/cub.cuh

Lines changed: 5 additions & 24 deletions
@@ -45,23 +45,17 @@
 
 #ifdef USE_ROCM
 #define NO_ROCM(x)
-#define ROCM_HIPCUB(x) ::hipcub
 #else
 #define NO_ROCM(x) x
-#define ROCM_HIPCUB(x) x
 #endif
 
-#if !CUB_SUPPORTS_NV_BFLOAT16() || \
-    (defined(USE_ROCM) && ROCM_VERSION >= 40500)
+#if !defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()
 
-#if !defined(USE_ROCM)
 namespace at_cuda_detail {
-#endif
-
 // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
 
 template <>
-struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
+struct cub::FpLimits<c10::BFloat16>
 {
   static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
     unsigned short max_word = 0x7F7F;
@@ -74,14 +68,8 @@ struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
   }
 };
 
-template <>
-struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
-  ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
-
-#if !defined(USE_ROCM)
-} // namespace at_cuda_detail
-#endif
-
+template <> struct cub::NumericTraits<c10::BFloat16>: cub::BaseTraits<cub::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
+}
 #endif
 
 #if !defined(USE_ROCM)
@@ -105,20 +93,13 @@ struct cuda_type<c10::Half> {
   using type = __half;
 };
 
-#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
+#if CUB_SUPPORTS_NV_BFLOAT16()
 
 template<>
 struct cuda_type<c10::BFloat16> {
   using type = __nv_bfloat16;
 };
 
-#elif (defined(USE_ROCM) && ROCM_VERSION >= 40500)
-
-template<>
-struct cuda_type<c10::BFloat16> {
-  using type = hip_bfloat16;
-};
-
 #endif
 
 } // namespace detail
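As a side note on the backported FpLimits specialization above, a minimal host-only sketch (plain C++, no CUDA or cub headers; the helper name is made up) of why 0x7F7F is the right bit pattern for the largest finite bfloat16: bfloat16 is the top 16 bits of an IEEE-754 float32, so widening the word with 16 zero bits and bit-casting gives the represented value. The 0xFF7F counterpart for Lowest() is an assumption based on flipping the sign bit; that line is truncated out of this hunk.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode a raw bfloat16 bit pattern into the float it represents: bfloat16
// keeps float32's sign and 8 exponent bits but only the top 7 mantissa bits.
float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;  // zero-fill dropped mantissa bits
  float out;
  std::memcpy(&out, &widened, sizeof(out));  // well-defined bit cast
  return out;
}

int main() {
  // 0x7F7F = sign 0, exponent 0xFE, mantissa 0x7F: the largest finite value,
  // matching FpLimits<c10::BFloat16>::Max() in the hunk above.
  std::printf("max    = %g\n", bf16_bits_to_float(0x7F7F));  // ~3.38953e+38
  // 0xFF7F flips only the sign bit (assumed counterpart for Lowest()).
  std::printf("lowest = %g\n", bf16_bits_to_float(0xFF7F));
  return 0;
}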

aten/src/ATen/native/cuda/Sort.cu

Lines changed: 5 additions & 5 deletions
@@ -325,14 +325,14 @@ void launch_stable_sort_kernel(
   TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
   int64_t *indices_ptr = indices.data_ptr<int64_t>();
 
-#if (defined(USE_ROCM) && ROCM_VERSION < 40500)
-  constexpr bool is_rocm_bf16_sort_unsupported = true;
+#if defined(USE_ROCM)
+  constexpr bool is_rocm = true;
 #else
-  constexpr bool is_rocm_bf16_sort_unsupported = false;
+  constexpr bool is_rocm = false;
 #endif
 
   AT_DISPATCH_ALL_TYPES_AND3(kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&]{
-    c10::guts::if_constexpr<!(is_rocm_bf16_sort_unsupported && std::is_same<scalar_t, c10::BFloat16>::value)>([&](auto _){
+    c10::guts::if_constexpr<!(is_rocm && std::is_same<scalar_t, c10::BFloat16>::value)>([&](auto _){
       const scalar_t *self_ptr = self.data_ptr<scalar_t>();
       scalar_t *values_ptr = values.data_ptr<scalar_t>();
       int64_t remaining = _(numel);
@@ -353,7 +353,7 @@ void launch_stable_sort_kernel(
       values_ptr += n;
      indices_ptr += n;
     }
-  }, [&](auto _){ TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm < 4.5"); });
+  }, [&](auto _){ TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm"); });
   });
 }
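The c10::guts::if_constexpr call above is a pre-C++17 stand-in for `if constexpr`: the rejected branch is never instantiated, which is what keeps the BFloat16 cub path out of ROCm builds entirely (a plain runtime `if` would still require the unsupported instantiation to compile). A minimal host-only sketch of the same dispatch, assuming C++17 and stand-in types rather than the real ATen machinery:

#include <cstdio>
#include <stdexcept>
#include <type_traits>

struct BFloat16 {};  // stand-in for c10::BFloat16

#if defined(USE_ROCM)
constexpr bool is_rocm = true;
#else
constexpr bool is_rocm = false;
#endif

template <typename scalar_t>
void launch_sort_for() {
  if constexpr (!(is_rocm && std::is_same<scalar_t, BFloat16>::value)) {
    // Supported (type, platform) pair: the real code launches the cub
    // radix-sort path here; the instantiation exists only in this branch.
    std::printf("launched sort\n");
  } else {
    // Only this branch is compiled for BFloat16 on ROCm, so the cub
    // instantiation that ROCm cannot provide is never referenced.
    throw std::runtime_error("BFloat16 is not supported on ROCm");
  }
}

int main() {
  launch_sort_for<float>();
  try {
    launch_sort_for<BFloat16>();  // throws only in -DUSE_ROCM builds
  } catch (const std::runtime_error& e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}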

test/test_sort_and_select.py

Lines changed: 7 additions & 0 deletions
@@ -135,6 +135,8 @@ def test_sort(self, device):
     # FIXME: remove torch.bool from unsupported types once support is added for cub sort
     @dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
     def test_stable_sort(self, device, dtype):
+        if TEST_WITH_ROCM and dtype == torch.bfloat16:
+            return
         sizes = (100, 1000, 10000)
         for ncopies in sizes:
             x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
@@ -228,6 +230,8 @@ def test_topk_1d_output_discontiguous(self, device, dtype):
     # FIXME: remove torch.bool from unsupported types once support is added for cub sort
     @dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
     def test_stable_sort_against_numpy(self, device, dtype):
+        if TEST_WITH_ROCM and dtype == torch.bfloat16:
+            return
         if dtype in floating_types_and(torch.float16, torch.bfloat16):
             inf = float('inf')
             neg_inf = -float('inf')
@@ -291,6 +295,9 @@ def repeated_index_fill(t, dim, idxs, vals):
 
     @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
     def test_msort(self, device, dtype):
+        if TEST_WITH_ROCM and dtype == torch.bfloat16:
+            return
+
         def test(shape):
             tensor = make_tensor(shape, device, dtype, low=-9, high=9)
             if tensor.size() != torch.Size([]):

torch/testing/_internal/common_methods_invocations.py

Lines changed: 2 additions & 2 deletions
@@ -13285,7 +13285,7 @@ def ref_pairwise_distance(input1, input2):
     OpInfo('sort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
-           dtypesIfROCM=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=all_types_and(torch.float16),
            sample_inputs_func=sample_inputs_sort,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -13931,7 +13931,7 @@ def ref_pairwise_distance(input1, input2):
     OpInfo('msort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
-           dtypesIfROCM=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=all_types_and(torch.float16),
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,

0 commit comments