namespace at::native {

+#ifdef USE_ROCM
+// Custom configuration for vectorized elementwise kernel
+// with template instantiation.
+namespace vectorized_templated_config {
+constexpr int num_threads() {
+  return 512;
+}
+
+constexpr int elems_per_thread() {
+  return 32;
+}
+
+constexpr int block_work_size() {
+  return elems_per_thread() * num_threads();
+}
+} // namespace vectorized_templated_config
+#endif

template <typename args_t, size_t... Is>
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
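(Note: with the configuration above, each thread block covers block_work_size() = num_threads() * elems_per_thread() = 512 * 32 = 16384 elements.)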
@@ -255,6 +272,139 @@ static inline void launch_vectorized_kernel(
}
}

+#ifdef USE_ROCM
+template <
+    int vec_size,
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
+__global__ void vectorized_templated_elementwise_kernel(
+    int N,
+    func_t f,
+    array_t data,
+    inp_calc_t inp_calc,
+    out_calc_t out_calc,
+    loader_t loader,
+    storer_t storer) {
+  int remaining =
+      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  if (remaining <
+      vectorized_templated_config::block_work_size()) { // if this block handles
+                                                        // the remainder,
+                                                        // just do a naive unrolled loop
+    auto policy = memory::policies::unroll_base<
+        vectorized_templated_config::num_threads(),
+        array_t,
+        inp_calc_t,
+        out_calc_t,
+        loader_t,
+        storer_t,
+        vectorized_templated_config::elems_per_thread()>(
+        data, remaining, inp_calc, out_calc, loader, storer);
+    elementwise_kernel_helper(f, policy);
+  } else { // if this block has a full `block_work_size` of data to handle, use
+           // vectorized memory access
+    elementwise_kernel_helper(
+        f,
+        memory::policies::vectorized_templated<
+            vec_size,
+            array_t,
+            vectorized_templated_config::elems_per_thread(),
+            vectorized_templated_config::num_threads(),
+            OutputType,
+            InputTypes...>(data));
+  }
+}
+
+// This function assumes trivial 1d and supports template specialization
+// to avoid dynamic casting.
+// Input vectorization size is based on runtime information, i.e.
+// the actual data types of the input and output tensors, and cannot
+// be determined using the functor type, as in regular non-templated
+// vectorized kernels. The caller is in charge of selecting the correct input
+// vectorization length.
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+static inline void launch_vectorized_templated_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  using traits = function_traits<func_t>;
+  int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
+      vectorized_templated_config::block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int vec_size = memory::can_vectorize_up_to<func_t>(data);
+  switch (vec_size) {
+    case 8:
+      vectorized_templated_elementwise_kernel<
+          8,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 4:
+      vectorized_templated_elementwise_kernel<
+          4,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 2:
+      vectorized_templated_elementwise_kernel<
+          2,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    default:
+      // vector size 1 is not handled as part of the vectorized templated kernel
+      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
+  }
+}
+#endif
+

template <
    typename func_t,
    typename array_t,
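A quick check on the launch math above: grid is the ceiling of N / block_work_size(), so for example N = 1,000,000 elements launches ceil(1,000,000 / 16,384) = 62 blocks of 512 threads, and only the last, partially filled block takes the unrolled fallback path inside the kernel.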
@@ -392,6 +542,46 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
});
}

+#ifdef USE_ROCM
+namespace {
+template <typename TupleLike, size_t arity, size_t arg_num = 0>
+struct check_types {
+  constexpr static inline bool check() {
+    if constexpr (arity != 2)
+      return false;
+    if constexpr (arg_num == 0) {
+      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType>)
+        return check_types<TupleLike, arity, arg_num + 1>::check();
+    } else if constexpr (arg_num == 1) {
+      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType2>)
+        return check_types<TupleLike, arity, arg_num + 1>::check();
+    }
+    return false;
+  }
+};
+
+// Bottom case: if we got this far, assume correct type matching except
+// when there are no arguments (arity == 0).
+template <typename TupleLike, size_t arity>
+struct check_types<TupleLike, arity, arity> {
+  constexpr static inline bool check() {
+    if constexpr (arity != 0)
+      return true;
+    return false;
+  }
+};
+
+template <typename TupleLike>
+struct check_types<TupleLike, 0, 0> {
+  constexpr static inline bool check() {
+    return false;
+  }
+};
+} // namespace
+#endif
+

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (!needs_dynamic_casting<func_t>::check(iter)) {
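To make the intent of check_types concrete, here is a small compile-time sketch (illustrative only, not part of the change; it assumes the anonymous-namespace helper above is visible in the same translation unit): only an argument tuple of exactly (float, float) with arity 2 passes, and any other arity or argument type is rejected.

// Illustrative compile-time checks against the check_types helper above.
static_assert(check_types<std::tuple<float, float>, 2>::check(),
              "a binary (float, float) argument tuple is accepted");
static_assert(!check_types<std::tuple<float, double>, 2>::check(),
              "any non-float argument is rejected");
static_assert(!check_types<std::tuple<float>, 1>::check(),
              "non-binary arities are rejected");
static_assert(!check_types<std::tuple<>, 0>::check(),
              "zero-argument functors are rejected");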
@@ -416,6 +606,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (contiguous) {
#ifdef USE_ROCM
+    // Attempt to call a specialized vectorized elementwise kernel
+    // that enables interleaving.
+    using float_map = c10::CppTypeToScalarType<float>;
+    using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
+    if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
+        iter.input_dtype(1) == bfloat16_map::value &&
+        memory::can_vectorize_up_to<func_t>(data) > 1) {
+      // constexpr to reduce the number of (empty) kernels generated for the
+      // vectorized templated elementwise path and to limit which functors are
+      // actually applied to the load and store at compile time.
+      using func_tuple = typename traits::ArgsTuple;
+      if constexpr (
+          std::is_same_v<float, arg0_t> && traits::arity == 2 &&
+          check_types<func_tuple, traits::arity, 0>::check()) {
+        auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
+        auto output_offset_calculator = TrivialOffsetCalculator<1>();
+        auto loader = memory::LoadWithCast<traits::arity>(iter);
+        auto storer = memory::StoreWithCast<1>(iter);
+        launch_vectorized_templated_kernel<
+            func_t,
+            std::array<char*, ntensors>,
+            decltype(input_offset_calculator),
+            decltype(output_offset_calculator),
+            decltype(loader),
+            decltype(storer),
+            float,
+            float,
+            BFloat16>(
+            numel,
+            f,
+            data,
+            input_offset_calculator,
+            output_offset_calculator,
+            loader,
+            storer);
+        return;
+      }
+    }
+
    std::array<ScalarType, ntensors> dtypes;
    auto inner_strides = iter.get_inner_strides();
    std::array<int, ntensors> strides;
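For context, a minimal sketch of the kind of call this new branch targets (an illustrative assumption, not code from the change): a contiguous iterator whose output and first input are float and whose second input is BFloat16, driven by a binary functor over two floats, since LoadWithCast converts the BFloat16 input before the functor runs.

// Illustrative only: a functor of this shape (float output, two float args)
// over a float/BFloat16 input pair satisfies the dtype and check_types
// conditions above and is routed to launch_vectorized_templated_kernel.
gpu_kernel(iter, [] GPU_LAMBDA(float a, float b) -> float { return a + b; });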