[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705) · pytorch/pytorch@5228986 · GitHub

Commit 5228986

malfet authored and pytorchmergebot committed
[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)
Addresses feedback requested on #145746.

Pull Request resolved: #150705
Approved by: https://github.com/atalman
1 parent e9e5682 commit 5228986
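The change hinges on a single preprocessor gate that is repeated in each touched file. A minimal sketch of that gate (illustrative, not a PyTorch header; in CUDA builds CUDA_VERSION is defined by cuda.h as major * 1000 + minor * 10, so 12080 corresponds to CUDA 12.8):

// Sketch of the version gate used throughout this commit.
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
// ROCm builds and CUDA toolkits older than 12.8 keep the existing vectorization path.
#else
// CUDA 12.8 and newer may also dispatch the 8-element (vec128) vectorized kernels.
#endif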

File tree

4 files changed: +23 -11 lines changed


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 4 additions & 2 deletions
@@ -78,7 +78,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
 
-#ifdef USE_ROCM
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -219,7 +219,7 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#ifdef USE_ROCM
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
@@ -241,11 +241,13 @@ static inline void launch_vectorized_kernel(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
 #endif
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
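For context, the grid computed in launch_vectorized_kernel shrinks when a thread is allowed to cover 8 elements instead of 4. A standalone sketch of that arithmetic, assuming the usual 32-thread warps (num_threads() == C10_WARP_SIZE * 4 == 128); the helper names here are illustrative, not PyTorch's:

// Illustrative only: blocks needed for a 1M-element elementwise launch
// with 4 vs. 8 elements of work per thread.
#include <cstdint>
#include <cstdio>

constexpr int64_t kNumThreads = 128;  // C10_WARP_SIZE * 4 on NVIDIA hardware

constexpr int64_t grid_for(int64_t N, int64_t thread_work_size) {
  const int64_t block_work_size = kNumThreads * thread_work_size;
  return (N + block_work_size - 1) / block_work_size;  // same ceil-divide as above
}

int main() {
  const int64_t N = 1 << 20;
  std::printf("thread_work_size=4 -> %lld blocks\n", static_cast<long long>(grid_for(N, 4)));
  std::printf("thread_work_size=8 -> %lld blocks\n", static_cast<long long>(grid_for(N, 8)));
  return 0;
}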

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 3 additions & 1 deletion
@@ -486,7 +486,9 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+#endif
 #ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
@@ -495,7 +497,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#else
+#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080
   if (address % vec8_alignment == 0) {
     return 8;
   } else
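The function being patched picks a vector width from pointer alignment. A host-side sketch of the CUDA branch of that logic (the ROCm-only vec16 path and its type-size cap are omitted; it assumes, as in the real header, that aligned_vector<T, N> is aligned to sizeof(T) * N, and the vec128_enabled flag stands in for the CUDA_VERSION >= 12080 check):

// Illustrative re-statement of the alignment check, not the ATen implementation.
#include <cstdint>

template <typename scalar_t>
int can_vectorize_up_to_sketch(const char* pointer, bool vec128_enabled) {
  const uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr uint64_t vec2_alignment = sizeof(scalar_t) * 2;  // alignof(aligned_vector<scalar_t, 2>)
  constexpr uint64_t vec4_alignment = sizeof(scalar_t) * 4;  // alignof(aligned_vector<scalar_t, 4>)
  constexpr uint64_t vec8_alignment = sizeof(scalar_t) * 8;  // only consulted on CUDA >= 12.8
  if (vec128_enabled && address % vec8_alignment == 0) {
    return 8;
  } else if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}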

aten/src/ATen/native/cuda/thread_constants.h

Lines changed: 4 additions & 1 deletion
@@ -18,8 +18,11 @@ constexpr int thread_work_size() { return 4; }
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-
+#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
+constexpr int thread_work_size() { return 4; }
+#else
 constexpr int thread_work_size() { return 8; }
 #endif
+#endif
 
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
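The "vec128" in the title presumably refers to the transaction width this enables: with thread_work_size() of 8, a thread covering 8 contiguous 16-bit elements (half or bfloat16) touches 16 bytes, i.e. one 128-bit access. Spelled out as arithmetic (illustrative, not part of the commit):

// Back-of-envelope arithmetic behind the "vec128" name.
static_assert(8 * sizeof(short) * 8 == 128,
              "8 two-byte elements (half/bfloat16) per thread = one 128-bit access");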

aten/src/ATen/test/cuda_vectorized_test.cu

Lines changed: 12 additions & 7 deletions
@@ -46,12 +46,17 @@ TEST(TestLoops, HasSameArgTypes) {
 
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
+#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
+  constexpr auto vectorize_limit = 4;
+#else
+  constexpr auto vectorize_limit = 8;
+#endif
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), vectorize_limit);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +70,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), vectorize_limit);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
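The expected values in the offset checks follow directly from the alignment rule: an address offset by 8 bytes from an aligned buffer can serve an 8-element load only for 1-byte types, a 4-element load for 2-byte types, and so on. A small sketch of that relationship (the base buffer is assumed suitably over-aligned, as in the test; the helper is illustrative, not the test's code):

// Illustrative: widest power-of-two vector width an offset from an aligned base allows.
#include <cstddef>

constexpr int widest_vec_at(std::size_t offset_bytes, std::size_t elem_size, int limit) {
  int vec = 1;
  while (vec < limit && offset_bytes % (elem_size * vec * 2) == 0) {
    vec *= 2;
  }
  return vec;
}

static_assert(widest_vec_at(8, sizeof(char),  8) == 8, "matches can_vectorize_up_to<int8_t>(ptr + 8)");
static_assert(widest_vec_at(8, sizeof(short), 8) == 4, "matches can_vectorize_up_to<int16_t>(ptr + 8)");
static_assert(widest_vec_at(8, sizeof(int),   8) == 2, "matches can_vectorize_up_to<int>(ptr + 8)");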

0 commit comments

Comments
 (0)
0