-
Notifications
You must be signed in to change notification settings - Fork 24.8k
[ATen][CUDA] Implement 128 bit vectorization v2 #145746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) { | |
} | ||
} | ||
|
||
#ifdef USE_ROCM | ||
template <int io_sizes> | ||
constexpr auto elems_per_thread(){ | ||
if constexpr (io_sizes == 1) { | ||
|
@@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){ | |
return 4; | ||
} | ||
} | ||
#else | ||
// Non-ROCm variant of elems_per_thread: this definition sits in the #else arm | ||
// of the USE_ROCM guard, so it is compiled only when USE_ROCM is not defined. | ||
// Evaluated at compile time from the template argument `io_sizes`: | ||
//   io_sizes == 1  -> 16 | ||
//   anything else  ->  8 | ||
// NOTE(review): presumably `io_sizes` is the combined byte size of the | ||
// functor's inputs and output (see calc_io_size / sum_of_sizes) - confirm. | ||
template <int io_sizes> | ||
constexpr auto elems_per_thread(){ | ||
if constexpr (io_sizes == 1) { | ||
return 16; | ||
} else { | ||
return 8; | ||
} | ||
} | ||
#endif | ||
|
||
template <int io_sizes> | ||
constexpr auto io_block_work_size() { | ||
|
@@ -191,21 +202,33 @@ static inline void launch_vectorized_kernel( | |
constexpr auto io_size = calc_io_size<func_t>(); | ||
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>(); | ||
auto stream = at::cuda::getCurrentCUDAStream(); | ||
#ifdef USE_ROCM | ||
int vec_size = memory::can_vectorize_up_to<func_t>(data); | ||
|
||
#else | ||
using cpp_type = typename function_traits<func_t>::result_type; | ||
const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data); | ||
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type)); | ||
vec_size = std::min<uint16_t>(vec_size, max_vec_size); | ||
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC | ||
// that causes some numerical mismatches with uint8 on sm80 and sm90. | ||
// TODO: Revisit this after CUDA 12.8 update. | ||
if constexpr (sizeof(cpp_type) < 2) { | ||
vec_size = std::min<uint16_t>(vec_size, 4); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are you setting the max vec size to 4 here for 1-byte datatypes? Is it to work around that bug? Can you leave a comment then? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Correct. This is a workaround for that bug. I have left a comment that explains it. |
||
} | ||
#endif | ||
switch (vec_size) { | ||
#ifdef USE_ROCM | ||
case 16: | ||
vectorized_elementwise_kernel<16, func_t, array_t> | ||
<<<grid, num_threads(), 0, stream>>>(N, f, data); | ||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
break; | ||
#endif | ||
case 8: | ||
vectorized_elementwise_kernel<8, func_t, array_t> | ||
<<<grid, num_threads(), 0, stream>>>(N, f, data); | ||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
break; | ||
#endif | ||
case 4: | ||
vectorized_elementwise_kernel<4, func_t, array_t> | ||
<<<grid, num_threads(), 0, stream>>>(N, f, data); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is elems_per_thread = 8 all-around better than the 4 we mostly used previously?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I observed little to no difference. The biggest improvement comes from `vec8`.