File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -226,8 +226,9 @@ C10_LAUNCH_BOUNDS_1(num_threads())
226226__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
227227 using traits = function_traits<func_t >;
228228 constexpr auto io_size = calc_io_size<func_t >();
229- #ifdef __gfx942__
230- constexpr int tws = (io_size >= 2 ) ? 8 : 16 ;
229+ #if defined(USE_ROCM) && defined(__gfx942__)
230+ // Similar check in launch_vectorized_kernel() as well. Both should be in sync.
231+ constexpr int tws = 16 ;
231232#else
232233 constexpr int tws = elems_per_thread<io_size>();
233234#endif
@@ -296,7 +297,8 @@ static inline void launch_vectorized_kernel(
296297 int vec_size = memory::can_vectorize_up_to<func_t >(data);
297298 c10::DeviceIndex curDevice = -1 ;
298299 AT_CUDA_CHECK (c10::cuda::GetDevice (&curDevice));
299- int tws = at::detail::getCUDAHooks ().isGPUArch ({"gfx942"}, curDevice) ? ((io_size >= 2 ) ? 8 : 16 ) : elems_per_thread<io_size>();
300+ // Similar check in vectorized_elementwise_kernel() as well. Both should be in sync.
301+ int tws = at::detail::getCUDAHooks ().isGPUArch ({"gfx942"}, curDevice) ? 16 : elems_per_thread<io_size>();
300302#else
301303 using cpp_type = typename function_traits<func_t >::result_type;
302304 const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t >(data);
You can’t perform that action at this time.
0 commit comments