[ROCm] Improvements for vectorized elementwise kernels by jerrymannil · Pull Request #143269 · pytorch/pytorch · GitHub
[ROCm] Improvements for vectorized elementwise kernels #143269

Closed
jerrymannil wants to merge 13 commits
Changes from 1 commit
review comments incorporated
jerrymannil committed Dec 16, 2024
commit 9d93f62f744a2d6054b1f10e11fdaa41519bdf15
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/jiterator.cu
@@ -60,12 +60,12 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
if (!fn_ptr->function) { // cache miss!
// Generates program
auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
- common_dtype, toOpMathType(common_dtype), common_dtype,
+ f_inputs_type_str, compute_type_str, result_type_str,
/*contiguous=*/true, /*dynamic_casting=*/false,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types,
- tws,
- vectorized, vec_size,
+ tws,
+ vectorized, vec_size,
return_by_ref);
std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
// Acquires the program
@@ -160,7 +160,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
const std::lock_guard<std::mutex> lock{_jiterator_mutex};
if (!fn_ptr->function) {
auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, name,
- common_dtype, toOpMathType(common_dtype), common_dtype,
+ f_inputs_type_str, compute_type_str, result_type_str,
contiguous, dynamic_casting,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types, tws, /*vectorized*/false, /*vec_size*/0, return_by_ref);
15 changes: 3 additions & 12 deletions aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -140,7 +140,9 @@ void launch_jitted_vectorized_kernel(
fn_ptr = &fn_cache.vec16;
} else if (vec_size == 8) {
fn_ptr = &fn_cache.vec8;
- } else if (vec_size == 4) {
+ } else
+ #endif
+ if (vec_size == 4) {
fn_ptr = &fn_cache.vec4;
} else if (vec_size == 2) {
fn_ptr = &fn_cache.vec2;
@@ -149,17 +151,6 @@
} else {
TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
}
- #else
- if (vec_size == 4) {
- fn_ptr = &fn_cache.vec4;
- } else if (vec_size == 2) {
- fn_ptr = &fn_cache.vec2;
- } else if (vec_size ==1) {
- fn_ptr = &fn_cache.vec1;
- } else {
- TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
- }
- #endif

bool vectorized = vec_size > 1;

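Note: the merged branch above leans on the dangling else just before the #endif, so on a ROCm build the chain reads vec16 → vec8 → vec4 → vec2 → vec1, while on a non-ROCm build the guarded cases compile away and the chain reduces to the old vec4/vec2/vec1 logic. A self-contained sketch of that pattern follows; all names are illustrative (USE_ROCM_LIKE stands in for USE_ROCM), not PyTorch code.

#include <cstdio>

// Demo of the guarded dispatch pattern: the wide cases sit inside the
// preprocessor guard and chain into the common cases through a trailing "else".
int pick_vec_width(int vec_size) {
#ifdef USE_ROCM_LIKE
  if (vec_size == 16) {
    return 16;
  } else if (vec_size == 8) {
    return 8;
  } else
#endif
  if (vec_size == 4) {
    return 4;
  } else if (vec_size == 2) {
    return 2;
  }
  return 1;  // the real code asserts on unexpected sizes instead of defaulting
}

int main() {
  // Prints "1 4" without -DUSE_ROCM_LIKE and "8 4" with it.
  std::printf("%d %d\n", pick_vec_width(8), pick_vec_width(4));
  return 0;
}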
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Dropout.cu
@@ -124,7 +124,7 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>

// Perform the actual computation
#pragma unroll
- for (int jj = 0; jj < RAND_SIZE; jj++) {
+ for (int jj = 0; jj < RAND_SIZE; jj++) {
#pragma unroll
@jerrymannil (Contributor, Author) commented on Dec 18, 2024:

Outer loop will only run once for non-ROCm, since RAND_SIZE = 1.
Inner loop is basically the same as the old code.

for (int ii = 0; ii < std::min(VEC, 4); ii++) {
r[jj * 4 + ii] = src[jj * 4 + ii]*(&rand[jj].x)[ii]*scale;
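The loop collapse the review comment above describes can be seen in isolation with a small host-side example. RAND_SIZE and VEC are stand-in constants here, not taken from the kernel; RAND_SIZE == 1 mirrors the non-ROCm value the comment mentions.

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int RAND_SIZE = 1;  // non-ROCm value per the review comment
  constexpr int VEC = 4;        // example vector width

  int body_runs = 0;
  for (int jj = 0; jj < RAND_SIZE; jj++) {           // outer loop: single pass when RAND_SIZE == 1
    for (int ii = 0; ii < std::min(VEC, 4); ii++) {  // inner loop: same bounds as the old code
      ++body_runs;  // stands in for r[jj * 4 + ii] = src[jj * 4 + ii] * (&rand[jj].x)[ii] * scale
    }
  }
  std::printf("loop body executed %d times\n", body_runs);  // prints 4
  return 0;
}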
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -364,7 +364,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
} else if (address % vec2_alignment == 0) {
return 2;
}
- #else
+ #else
if (address % vec4_alignment == 0) {
return 4;
} else if (address % vec2_alignment == 0) {
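For readers unfamiliar with can_vectorize_up_to, the alignment-probing pattern shown in this hunk can be reproduced standalone. The 16/8-byte constants below are illustrative for a 4-byte element type; the real code derives them from aligned_vector<scalar_t, N>, and the ROCm-guarded branch (mostly above this hunk's window) extends the same idea to wider vectors.

#include <cstdint>
#include <cstdio>

// Sketch of the alignment test: return the widest vector width whose alignment
// the pointer satisfies.
int can_vectorize_up_to_sketch(const char* pointer) {
  const uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec4_alignment = 16;
  constexpr int vec2_alignment = 8;
  if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}

int main() {
  alignas(16) float buf[8];
  const char* base = reinterpret_cast<const char*>(buf);
  std::printf("%d %d %d\n",
              can_vectorize_up_to_sketch(base),                      // 16-byte aligned -> 4
              can_vectorize_up_to_sketch(base + sizeof(float)),      // 4-byte offset   -> 1
              can_vectorize_up_to_sketch(base + 2 * sizeof(float))); // 8-byte offset   -> 2
  return 0;
}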
54 changes: 25 additions & 29 deletions aten/src/ATen/native/cuda/jit_utils.cpp
@@ -995,9 +995,9 @@ std::string generate_code(
desc.nOutputs,
desc.f,
desc.name,
- desc.f_inputs_type,
- toOpMathType(desc.f_inputs_type),
- desc.result_type,
+ typeName(desc.f_inputs_type),
+ typeName(toOpMathType(desc.f_inputs_type)),
+ typeName(desc.result_type),
contiguous,
dynamic_casting,
scalar_pos,
@@ -1013,9 +1013,9 @@
int nOutputs,
const std::string& func_,
const std::string& name,
- const c10::ScalarType& f_inputs_type,
- const c10::ScalarType& compute_type,
- const c10::ScalarType& result_type,
+ const std::string& f_inputs_type,
+ const std::string& compute_type,
+ const std::string& result_type,
bool contiguous,
bool dynamic_casting,
BinaryFuncVariant scalar_pos,
@@ -1027,15 +1027,11 @@
std::string func = func_;
at::jit::TemplateEnv env;

- const std::string f_inputs_type_str = typeName(f_inputs_type);
- const std::string compute_type_str = typeName(compute_type);
- const std::string result_type_str = typeName(result_type);
-
env.s("index_type", "unsigned int");
env.s("nInputs", std::to_string(nInputs));
env.s("nOutputs", std::to_string(nOutputs));
- env.s("scalar_type", f_inputs_type_str);
- env.s("compute_type", compute_type_str);
+ env.s("scalar_type", f_inputs_type);
+ env.s("compute_type", compute_type);
env.s("functor", func);
env.s("name", name);
env.s("cmath_string", get_cmath_string());
@@ -1060,14 +1056,14 @@
for (int i = 0; i < nInputs; i++) {
// TODO these arrays are potentially of the different types, use function
// traits to determine the types
- declare_load_arrays << f_inputs_type_str << " arg" << std::to_string(i)
+ declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_load_arrays", declare_load_arrays.str());

std::stringstream declare_store_arrays;
for (int i = 0; i < nOutputs; i++) {
- declare_store_arrays << result_type_str << " out" << std::to_string(i)
+ declare_store_arrays << result_type << " out" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_store_arrays", declare_store_arrays.str());
@@ -1089,7 +1085,7 @@

std::string call_functor_template;
if (return_by_ref) { // return one or more outputs by reference
- bool need_temp_out = (compute_type_str != result_type_str);
+ bool need_temp_out = (compute_type != result_type);
std::stringstream functor_outs;
if (need_temp_out) {
for (int i = 0; i < nOutputs - 1; i++) {
@@ -1122,24 +1118,24 @@
}
env.s("call_functor", at::jit::CodeTemplate(call_functor_template).format(env));

- if (f_inputs_type_str == "at::Half" || result_type_str == "at::Half" ||
- f_inputs_type_str == "std::complex<at::Half>" ||
- result_type_str == "std::complex<at::Half>" || dynamic_casting) {
+ if (f_inputs_type == "at::Half" || result_type == "at::Half" ||
+ f_inputs_type == "std::complex<at::Half>" ||
+ result_type == "std::complex<at::Half>" || dynamic_casting) {
// complex<Half> depends on complex<T> and Half dtypes.
env.s("half_string", jiterator_half_support_literal);
} else {
env.s("half_string", "");
}
- if (f_inputs_type_str == "at::BFloat16" || result_type_str == "at::BFloat16" || dynamic_casting) {
+ if (f_inputs_type == "at::BFloat16" || result_type == "at::BFloat16" || dynamic_casting) {
env.s("bfloat16_string", jiterator_bfloat16_support_literal);
} else {
env.s("bfloat16_string", "");
}
// the definition of complex math functions is only needed when the compute type is complex
// but the definition of std::complex is needed for dynamic casting even if the compute type is not complex
- if (f_inputs_type_str == "std::complex<float>" || result_type_str == "std::complex<float>" ||
- f_inputs_type_str == "std::complex<double>" || result_type_str == "std::complex<double>" ||
- f_inputs_type_str == "std::complex<at::Half>" || result_type_str == "std::complex<at::Half>") {
+ if (f_inputs_type == "std::complex<float>" || result_type == "std::complex<float>" ||
+ f_inputs_type == "std::complex<double>" || result_type == "std::complex<double>" ||
+ f_inputs_type == "std::complex<at::Half>" || result_type == "std::complex<at::Half>") {
// complex<Half> depends on complex<T> and Half dtypes.
env.s("traits_string", get_traits_string_but_hiprtc_safe());
env.s("complex_body_string", get_complex_body_string());
@@ -1158,8 +1154,8 @@
env.s("complex_body_string", "");
env.s("complex_math_string", "");
}
- if (f_inputs_type_str == "std::complex<at::Half>" ||
- result_type_str == "std::complex<at::Half>" || dynamic_casting) {
+ if (f_inputs_type == "std::complex<at::Half>" ||
+ result_type == "std::complex<at::Half>" || dynamic_casting) {
// dynamic_casting requires the definition of all types
// include complex<at::Half>
// Look at the definition of `StoreWithCast` and `LoadWithCast`.
@@ -1190,7 +1186,7 @@
std::stringstream load_inputs;
for (int i = 0; i < nInputs; i++) {
auto i_string = std::to_string(i);
- load_inputs << "arg" << i_string << "[j] = l.load<" << f_inputs_type_str
+ load_inputs << "arg" << i_string << "[j] = l.load<" << f_inputs_type
<< ">(data[" << std::to_string(i + nOutputs)
<< "], input_offsets[" << i_string << "], " << i_string
<< ");\n";
@@ -1200,7 +1196,7 @@
std::stringstream store_outputs;
for (int i = 0; i < nOutputs; i++) {
auto i_string = std::to_string(i);
- store_outputs << "s.store<" << result_type_str
+ store_outputs << "s.store<" << result_type
<< ">(out" << i_string << "[j], data[" << i_string
<< "], output_offsets[" << i_string << "], " << i_string
<< ");\n";
@@ -1215,7 +1211,7 @@

// vectorized case
env.s("vec_size", std::to_string(vec_size));
- env.s("result_type", result_type_str);
+ env.s("result_type", result_type);

std::stringstream vector_inputs;
for (const auto i : c10::irange(nInputs)){
@@ -1261,15 +1257,15 @@
std::stringstream load_unrolled_inputs;
for (const auto i: c10::irange(nInputs)){
auto i_string = std::to_string(i);
- load_unrolled_inputs << "arg" << i_string << "[j] = load<" << f_inputs_type_str
+ load_unrolled_inputs << "arg" << i_string << "[j] = load<" << f_inputs_type
<< ">(data[" << std::to_string(i + nOutputs) << "], linear_idx);\n";
}
env.s("load_unrolled_inputs", load_unrolled_inputs.str());

std::stringstream store_unrolled_outputs;
for (const auto i : c10::irange(nOutputs)) {
auto i_string = std::to_string(i);
- store_unrolled_outputs << "store<" << result_type_str << ">(out" << i_string
+ store_unrolled_outputs << "store<" << result_type << ">(out" << i_string
<< "[j], data[" << i_string << "], linear_idx);\n";
}
env.s("store_unrolled_outputs", store_unrolled_outputs.str());
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/jit_utils.h
@@ -112,9 +112,9 @@ std::string generate_code(
int nOutputs,
const std::string& func,
const std::string& name,
- const c10::ScalarType& f_inputs_type,
- const c10::ScalarType& compute_type,
- const c10::ScalarType& result_type,
+ const std::string& f_inputs_type,
+ const std::string& compute_type,
+ const std::string& result_type,
bool contiguous,
bool dynamic_casting,
BinaryFuncVariant scalar_pos,
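Taken together with the jit_utils.cpp changes above, the net effect is that generate_code's detailed overload now accepts pre-stringified type names, while the descriptor-based entry point converts once via typeName and forwards. A standalone illustration of that shape follows; every name in it is hypothetical and only the structure mirrors the diff.

#include <cstdio>
#include <string>

enum class DtypeLike { Float, Half };

std::string type_name_like(DtypeLike t) {
  return t == DtypeLike::Float ? "float" : "at::Half";
}

// String-based overload: splices the names straight into the generated source.
std::string generate_code_like(const std::string& f_inputs_type,
                               const std::string& compute_type,
                               const std::string& result_type) {
  return "load " + f_inputs_type + "; compute " + compute_type +
         "; store " + result_type + ";";
}

// Dtype-based overload: stringify once, then forward to the string overload.
std::string generate_code_like(DtypeLike f_inputs, DtypeLike compute, DtypeLike result) {
  return generate_code_like(type_name_like(f_inputs),
                            type_name_like(compute),
                            type_name_like(result));
}

int main() {
  // Callers that already hold strings (like jiterator.cu after this commit) pass them through.
  std::puts(generate_code_like("at::Half", "float", "at::Half").c_str());
  // Callers that start from dtypes use the forwarding overload.
  std::puts(generate_code_like(DtypeLike::Half, DtypeLike::Float, DtypeLike::Half).c_str());
  return 0;
}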