[ATen][CUDA] Implement 128 bit vectorization v2 by Aidyn-A · Pull Request #145746 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[ATen][CUDA] Implement 128 bit vectorization v2 #145746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions aten/src/ATen/native/cuda/CUDAJitLoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ struct JittedVecKernelCache {
at::cuda::jit::NvrtcFunction vec1;
at::cuda::jit::NvrtcFunction vec2;
at::cuda::jit::NvrtcFunction vec4;
#ifdef USE_ROCM
at::cuda::jit::NvrtcFunction vec8;
#ifdef USE_ROCM
at::cuda::jit::NvrtcFunction vec16;
#endif

Expand Down Expand Up @@ -131,18 +131,30 @@ void launch_jitted_vectorized_kernel(
int vec_size = at::cuda::jit::can_vectorize_up_to(
desc, c10::ArrayRef<char*>(data.data(), data.size()));

#ifndef USE_ROCM
const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
const int optimal_vec_size = 16 / static_cast<int>(input_size);
vec_size = std::min<int>(optimal_vec_size, vec_size);
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
// that causes some numerical mismatches with uint8 on sm80 and sm90.
// TODO: Revisit this after CUDA 12.8 update.
if (input_size < 2) {
vec_size = std::min<int>(vec_size, 4);
}
#endif

// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// fn_ptr is set to the appropriate function based on the vec size and GPU used
at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;

#ifdef USE_ROCM
if (vec_size == 16) {
fn_ptr = &fn_cache.vec16;
} else if (vec_size == 8) {
fn_ptr = &fn_cache.vec8;
} else
#endif
if (vec_size == 4) {
if (vec_size == 8) {
fn_ptr = &fn_cache.vec8;
} else if (vec_size == 4) {
fn_ptr = &fn_cache.vec4;
} else if (vec_size == 2) {
fn_ptr = &fn_cache.vec2;
Expand Down
27 changes: 25 additions & 2 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
}
}

#ifdef USE_ROCM
template <int io_sizes>
constexpr auto elems_per_thread(){
if constexpr (io_sizes == 1) {
Expand All @@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){
return 4;
}
}
#else
template <int io_sizes>
constexpr auto elems_per_thread(){
if constexpr (io_sizes == 1) {
return 16;
} else {
return 8;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is elems_per_thread = 8 all-around better than the 4 we mostly used previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I observed little to no difference. The biggest improvement comes from vec8.

}
}
#endif

template <int io_sizes>
constexpr auto io_block_work_size() {
Expand Down Expand Up @@ -191,21 +202,33 @@ static inline void launch_vectorized_kernel(
constexpr auto io_size = calc_io_size<func_t>();
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
auto stream = at::cuda::getCurrentCUDAStream();
#ifdef USE_ROCM
int vec_size = memory::can_vectorize_up_to<func_t>(data);

#else
using cpp_type = typename function_traits<func_t>::result_type;
const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
vec_size = std::min<uint16_t>(vec_size, max_vec_size);
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
// that causes some numerical mismatches with uint8 on sm80 and sm90.
// TODO: Revisit this after CUDA 12.8 update.
if constexpr (sizeof(cpp_type) < 2) {
vec_size = std::min<uint16_t>(vec_size, 4);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you setting the max vec size to 4 here for 1-byte datatypes? Is it to work around that bug? Can you leave a comment then?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. This is a workaround that bug. I have left a comment that explains it.

}
#endif
switch (vec_size) {
#ifdef USE_ROCM
case 16:
vectorized_elementwise_kernel<16, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
#endif
case 8:
vectorized_elementwise_kernel<8, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
#endif
case 4:
vectorized_elementwise_kernel<4, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/native/cuda/Dropout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,11 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
// make sure we don't break assumption that we can't have > 16 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
#else
const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
vec_size = std::min<int>(optimal_vec_size, vec_size);

// make sure we don't break assumption that we can't have > 4 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
#endif
}

Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/native/cuda/MemoryAccess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -351,15 +351,19 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
#ifdef USE_ROCM
constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
#ifdef USE_ROCM
constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
constexpr int type_size = sizeof(scalar_t);
if (type_size == 1 && (address % vec16_alignment == 0)) {
return 16;
} else if (type_size <= 2 && (address % vec8_alignment == 0)) {
return 8;
} else
#else
if (address % vec8_alignment == 0) {
return 8;
} else
#endif
if (address % vec4_alignment == 0) {
return 4;
Expand Down
11 changes: 8 additions & 3 deletions aten/src/ATen/native/cuda/jit_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,6 @@ void initializeCudaContext() {
}
}

#ifdef USE_ROCM
int calc_io_size(
const int nInputs,
const int nOutputs,
Expand All @@ -953,7 +952,6 @@ int calc_io_size(

return 0;
}
#endif

int calc_thread_work_size(
const int nInputs,
Expand All @@ -972,7 +970,14 @@ int calc_thread_work_size(
}
return io_size;
#else
return JIT_THREAD_WORK_SIZE;
auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
TORCH_INTERNAL_ASSERT(io_size > 0);
if (io_size == 1) {
return 16;
} else {
return 8;
}
return io_size;
#endif
}

Expand Down
10 changes: 8 additions & 2 deletions aten/src/ATen/native/cuda/jit_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
return 8;
}
#else
if (ip % (8 * default_alignment) == 0) {
return 8;
}
#endif
if (ip % (4 * default_alignment) == 0) {
return 4;
Expand Down Expand Up @@ -88,15 +92,17 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
}

//FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
#ifdef USE_ROCM
#define JIT_THREAD_WORK_SIZE 4
#else
#define JIT_THREAD_WORK_SIZE 8
#endif

#ifdef USE_ROCM
int calc_io_size(
const int nInputs,
const int nOutputs,
const c10::ScalarType& inputs_type,
const c10::ScalarType& result_type);
#endif

int calc_thread_work_size(
const int nInputs,
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/native/cuda/thread_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
constexpr int num_threads() {
return 256;
}

constexpr int thread_work_size() { return 4; }
#else
constexpr uint32_t num_threads() {
return C10_WARP_SIZE * 4;
}

constexpr int thread_work_size() { return 8; }
#endif

constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
14 changes: 7 additions & 7 deletions aten/src/ATen/test/cuda_vectorized_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
char *ptr = reinterpret_cast<char *>(buffer1);

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);
ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
Expand All @@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
Expand Down
Loading
0