[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705) · pytorch/pytorch@646116e

Commit 646116e

malfet authored and pytorchbot committed
[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)
Addresses feedback requested in #145746.
Pull Request resolved: #150705
Approved by: https://github.com/atalman
(cherry picked from commit 5228986)
1 parent 35f1e76 commit 646116e

File tree

4 files changed: +23 −11 lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh
Lines changed: 4 additions & 2 deletions

@@ -61,7 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
 
-#ifdef USE_ROCM
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -202,7 +202,7 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#ifdef USE_ROCM
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
@@ -224,11 +224,13 @@ static inline void launch_vectorized_kernel(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
 #endif
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
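The dispatch above only compiles the 8-wide case 8: branch for ROCm builds or when the CUDA toolkit is at least 12.8; older toolkits keep the previous 4-wide maximum. The gating uses the CUDA_VERSION macro from cuda.h, which encodes the toolkit version as major * 1000 + minor * 10, so 12.8 reports 12080. Below is a minimal, self-contained sketch of that pattern; it is not PyTorch code, and max_vec_elems is an illustrative name:

// Minimal sketch, not PyTorch code: cuda.h defines CUDA_VERSION as
// major * 1000 + minor * 10, so CUDA 12.8 reports 12080. A guard like the one
// added in this commit selects the widest vectorization the toolkit supports.
#include <cuda.h>
#include <cstdio>

constexpr int max_vec_elems() {
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
  return 8;  // vec128 path: the 8-wide vectorized elementwise kernel is compiled in
#else
  return 4;  // older CUDA toolkits keep the previous 4-wide limit
#endif
}

int main() {
  std::printf("max vectorization: %d elements per thread\n", max_vec_elems());
  return 0;
}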

aten/src/ATen/native/cuda/MemoryAccess.cuh
Lines changed: 3 additions & 1 deletion

@@ -351,7 +351,9 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+#endif
 #ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
@@ -360,7 +362,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#else
+#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080
   if (address % vec8_alignment == 0) {
     return 8;
   } else
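With this guard, the 8-element alignment constant is only declared when the 8-wide path can actually be taken, and the new #elif lets CUDA 12.8+ builds test that alignment before falling back to 4, 2, or 1. A standalone sketch of the alignment cascade the function relies on; aligned_vector and vec_width_for here are illustrative names, not the PyTorch definitions:

// Standalone sketch of the alignment cascade behind can_vectorize_up_to.
// aligned_vector and vec_width_for are illustrative, not the PyTorch API.
#include <cstdint>
#include <type_traits>

template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <typename scalar_t>
int vec_width_for(const char* pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
  constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
  // In the real code the 8-wide check is only compiled for ROCm or CUDA >= 12.8.
  if (address % vec8_alignment == 0) return 8;
  if (address % vec4_alignment == 0) return 4;
  if (address % vec2_alignment == 0) return 2;
  return 1;
}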

aten/src/ATen/native/cuda/thread_constants.h
Lines changed: 4 additions & 1 deletion

@@ -18,8 +18,11 @@ constexpr int thread_work_size() { return 4; }
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-
+#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
+constexpr int thread_work_size() { return 4; }
+#else
 constexpr int thread_work_size() { return 8; }
 #endif
+#endif
 
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
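Because block_work_size() is derived from thread_work_size(), raising the per-thread work from 4 to 8 on CUDA 12.8+ also doubles how many elements each block covers, which is what the grid computation in launch_vectorized_kernel divides by. A simplified sketch follows; it drops the surrounding USE_ROCM branch, assumes a warp size of 32, and grid_for is an illustrative helper, not part of the header:

// Simplified sketch of how the constants above feed the launch geometry.
// Assumes a warp size of 32 and omits the USE_ROCM branch; grid_for is illustrative.
#include <cstdint>

constexpr uint32_t num_threads() { return 32 * 4; }   // C10_WARP_SIZE * 4 on NVIDIA

#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
constexpr int thread_work_size() { return 4; }        // pre-12.8 toolkits: 4 elements per thread
#else
constexpr int thread_work_size() { return 8; }        // CUDA 12.8+: 8 elements per thread
#endif

constexpr int block_work_size() { return thread_work_size() * num_threads(); }

// Each block covers block_work_size() elements, so the grid is a ceiling divide.
constexpr int64_t grid_for(int64_t N) {
  return (N + block_work_size() - 1) / block_work_size();
}

static_assert(grid_for(1) == 1, "one partial block is still launched");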

aten/src/ATen/test/cuda_vectorized_test.cu
Lines changed: 12 additions & 7 deletions

@@ -46,12 +46,17 @@ TEST(TestLoops, HasSameArgTypes) {
 
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
+#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
+  constexpr auto vectorize_limit = 4;
+#else
+  constexpr auto vectorize_limit= 8;
+#endif
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), vectorize_limit);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +70,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), vectorize_limit);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
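The test now derives its expected maximum from CUDA_VERSION instead of hard-coding 8, while the offset cases stay fixed because they depend only on pointer alignment. The expectations at ptr + 8 follow from the alignment rule the diff implies (alignment = element size × vector width); a self-contained sketch of that arithmetic, using an illustrative aligned_vector rather than the PyTorch one:

// Sketch of the alignment arithmetic behind the ptr + 8 expectations above.
// aligned_vector is illustrative; it mirrors the alignment rule implied by the diff.
#include <cstdint>

template <typename T, int N>
struct alignas(sizeof(T) * N) aligned_vector { T val[N]; };

// A pointer that is 8-byte but not 16-byte aligned (such as buffer1 + 8) still
// satisfies any vector type whose alignment requirement is 8 bytes or less:
static_assert(alignof(aligned_vector<int8_t, 8>)  == 8,  "8 x 1B load -> still the full limit");
static_assert(alignof(aligned_vector<int16_t, 8>) == 16, "8 x 2B needs 16B alignment -> falls back to 4");
static_assert(alignof(aligned_vector<int, 4>)     == 16, "4 x 4B needs 16B alignment -> falls back to 2");
static_assert(alignof(aligned_vector<int64_t, 2>) == 16, "2 x 8B needs 16B alignment -> falls back to 1");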
