[ROCm] Input vectorization in elementwise kernels for tensors with heterogeneous types · pytorch/pytorch@8bc7bd9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8bc7bd9

Browse files
carlobertolli and amd-hhashemi
authored and committed
[ROCm] Input vectorization in elementwise kernels for tensors with heterogeneous types (#147527)
This patch exemplifies its use for input tensors with types (float,bfloat16) when functor type is float(float,float). Pull Request resolved: #147527 Approved by: https://github.com/jeffdaily Co-authored-by: Hashem Hashemi <hashem.hashemi@amd.com>
1 parent e8dd58b commit 8bc7bd9

File tree

2 files changed

+388
-24
lines changed

2 files changed

+388
-24
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@
5151

5252
namespace at::native {
5353

54+
#ifdef USE_ROCM
55+
// Launch configuration for the ROCm-only vectorized elementwise kernel
// that is instantiated per input/output type combination.
namespace vectorized_templated_config {

// Threads launched per block.
constexpr int num_threads() {
  return 512;
}

// Elements processed by each individual thread.
constexpr int elems_per_thread() {
  return 32;
}

// Total number of elements one block is responsible for.
constexpr int block_work_size() {
  return num_threads() * elems_per_thread();
}

} // namespace vectorized_templated_config
70+
#endif
5471

5572
template <typename args_t, size_t... Is>
5673
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
@@ -255,6 +272,139 @@ static inline void launch_vectorized_kernel(
255272
}
256273
}
257274

275+
#ifdef USE_ROCM
276+
// Elementwise kernel with compile-time-known input/output types
// (OutputType, InputTypes...), enabling vectorized loads even when the
// functor's argument types differ from the stored tensor types.
// Full tiles take the vectorized path; the remainder tile falls back to
// a bounds-checked unrolled loop.
template <
    int vec_size,
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
__global__ void vectorized_templated_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t inp_calc,
    out_calc_t out_calc,
    loader_t loader,
    storer_t storer) {
  constexpr int thread_count = vectorized_templated_config::num_threads();
  constexpr int per_thread = vectorized_templated_config::elems_per_thread();
  constexpr int tile = vectorized_templated_config::block_work_size();
  // Elements this block still has in front of it; only the last block
  // can see fewer than a full tile.
  int remaining = N - tile * blockIdx.x;
  if (remaining >= tile) {
    // Full tile: use the templated vectorized load/store policy.
    elementwise_kernel_helper(
        f,
        memory::policies::vectorized_templated<
            vec_size,
            array_t,
            per_thread,
            thread_count,
            OutputType,
            InputTypes...>(data));
  } else {
    // Remainder tile: plain unrolled loop that bounds-checks every access.
    auto tail_policy = memory::policies::unroll_base<
        thread_count,
        array_t,
        inp_calc_t,
        out_calc_t,
        loader_t,
        storer_t,
        per_thread>(data, remaining, inp_calc, out_calc, loader, storer);
    elementwise_kernel_helper(f, tail_policy);
  }
}
324+
325+
// Launches vectorized_templated_elementwise_kernel for a trivially
// contiguous 1d iterator, using template specialization to avoid
// dynamic casting on the data path.
// The input vectorization width is runtime information (it depends on
// the actual input/output tensor types and pointer alignment) and cannot
// be determined from the functor type alone, as is done for the regular
// non-templated vectorized kernels. The caller is in charge of selecting
// the correct input vectorization length.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
static inline void launch_vectorized_templated_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  // The kernel takes a 32-bit element count; this assert guarantees the
  // narrowing conversion below is lossless.
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  constexpr int64_t block_work = vectorized_templated_config::block_work_size();
  const int64_t grid = (N + block_work - 1) / block_work;
  auto stream = at::cuda::getCurrentCUDAStream();
  // Launch helper so the per-width cases below do not triplicate the
  // template argument list and launch boilerplate.
  auto launch = [&](auto vec) {
    vectorized_templated_elementwise_kernel<
        decltype(vec)::value,
        func_t,
        array_t,
        inp_calc_t,
        out_calc_t,
        loader_t,
        storer_t,
        OutputType,
        InputTypes...>
        <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
            N, f, data, ic, oc, l, s);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  };
  switch (memory::can_vectorize_up_to<func_t>(data)) {
    case 8:
      launch(std::integral_constant<int, 8>{});
      break;
    case 4:
      launch(std::integral_constant<int, 4>{});
      break;
    case 2:
      launch(std::integral_constant<int, 2>{});
      break;
    default:
      // Vector size 1 is deliberately not handled as part of the
      // vectorized templated kernel; callers fall back to the regular
      // kernels for unvectorizable inputs.
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
406+
#endif
407+
258408
template <
259409
typename func_t,
260410
typename array_t,
@@ -392,6 +542,46 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
392542
});
393543
}
394544

545+
#ifdef USE_ROCM
546+
namespace {
// Compile-time predicate over a functor's argument tuple: true iff the
// functor takes exactly two arguments and both are `float`. Used to
// restrict instantiation of the templated vectorized kernel to the
// supported functor signatures.
template <typename TupleLike, size_t arity, size_t arg_num = 0>
struct check_types {
  constexpr static inline bool check() {
    if constexpr (arity != 2)
      return false;
    if constexpr (arg_num == 0 || arg_num == 1) {
      using CurrentArg = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<float, CurrentArg>)
        return check_types<TupleLike, arity, arg_num + 1>::check();
    }
    return false;
  }
};

// Terminal case: every argument was inspected and matched; empty
// argument lists (arity == 0) are still rejected.
template <typename TupleLike, size_t arity>
struct check_types<TupleLike, arity, arity> {
  constexpr static inline bool check() {
    return arity != 0;
  }
};

// More specialized zero-arity case, disambiguating against the terminal
// specialization above.
template <typename TupleLike>
struct check_types<TupleLike, 0, 0> {
  constexpr static inline bool check() {
    return false;
  }
};
} // namespace
583+
#endif
584+
395585
template <typename func_t>
396586
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
397587
if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -416,6 +606,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
416606

417607
if (contiguous) {
418608
#ifdef USE_ROCM
609+
// Attempt to call specialized vectorized elementwise kernel
610+
// that enables interleaving.
611+
using float_map = c10::CppTypeToScalarType<float>;
612+
using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
613+
if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
614+
iter.input_dtype(1) == bfloat16_map::value &&
615+
memory::can_vectorize_up_to<func_t>(data) > 1) {
616+
// constexpr to reduce the amount of kernels (empty) generated for
617+
// vectorized templated elementwise and limit which functors are actually
618+
// applied to the load and store at compile time.
619+
using func_tuple = typename traits::ArgsTuple;
620+
if constexpr (
621+
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
622+
check_types<func_tuple, traits::arity, 0>::check()) {
623+
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
624+
auto output_offset_calculator = TrivialOffsetCalculator<1>();
625+
auto loader = memory::LoadWithCast<traits::arity>(iter);
626+
auto storer = memory::StoreWithCast<1>(iter);
627+
launch_vectorized_templated_kernel<
628+
func_t,
629+
std::array<char*, ntensors>,
630+
decltype(input_offset_calculator),
631+
decltype(output_offset_calculator),
632+
decltype(loader),
633+
decltype(storer),
634+
float,
635+
float,
636+
BFloat16>(
637+
numel,
638+
f,
639+
data,
640+
input_offset_calculator,
641+
output_offset_calculator,
642+
loader,
643+
storer);
644+
return;
645+
}
646+
}
647+
419648
std::array<ScalarType, ntensors> dtypes;
420649
auto inner_strides = iter.get_inner_strides();
421650
std::array<int, ntensors> strides;

0 commit comments

Comments
 (0)
0