8000 [ATen][CUDA] Optimize 128 bit vectorization by pytorchbot · Pull Request #152967 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[ATen][CUDA] Optimize 128 bit vectorization #152967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,69 @@ constexpr auto calc_io_size(){
#endif
}

#ifndef USE_ROCM
// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
// used on sm_90 and sm_100 exclusively.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;

if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
auto policy = memory::policies::unroll<
array_t,
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
} else {
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;

if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
auto policy = memory::policies::unroll<
array_t,
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
}

#else // USE_ROCM
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
Expand All @@ -157,15 +220,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
#ifdef USE_ROCM
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
#else
constexpr auto optimal_vec_size = vec_size;
#endif
elementwise_kernel_helper(
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
#endif // USE_ROCM

template <
typename func_t,
Expand Down Expand Up @@ -212,6 +272,11 @@ static inline void launch_vectorized_kernel(
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
// that causes some numerical mismatches with uint8 on sm80 and sm90.
// TODO: Revisit this after CUDA 12.8 update.
cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
const int computeCapability = p->major * 10 + p->minor;
if (computeCapability != 90 && computeCapability != 100) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
if constexpr (sizeof(cpp_type) < 2) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
Expand Down
Loading
0