check the sm version · pytorch/pytorch@b0c21f7 · GitHub

Commit b0c21f7
check the sm version
1 parent d22c4cc commit b0c21f7

1 file changed: +71 −6 lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 71 additions & 6 deletions
@@ -150,6 +150,71 @@ constexpr auto calc_io_size(){
 #endif
 }
 
+#ifndef USE_ROCM
+// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
+// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
+// used on sm_90 and sm_100 exclusively.
+template <int vec_size, typename func_t, typename array_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  if constexpr (vec_size == 8) {
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      constexpr auto optimal_vec_size = vec_size;
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+  } else {
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      constexpr auto optimal_vec_size = vec_size;
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+  }
+}
+
+#else // USE_ROCM
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -174,15 +239,10 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-#ifdef USE_ROCM
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
-#else
-    constexpr auto optimal_vec_size = vec_size;
-#endif
-    elementwise_kernel_helper(
-        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }
+#endif // USE_ROCM
 
 template <
     typename func_t,
@@ -229,6 +289,11 @@ static inline void launch_vectorized_kernel(
   // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
   // that causes some numerical mismatches with uint8 on sm80 and sm90.
   // TODO: Revisit this after CUDA 12.8 update.
+  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
+  const int computeCapability = p->major * 10 + p->minor;
+  if (computeCapability != 90 && computeCapability != 100) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
   if constexpr (sizeof(cpp_type) < 2) {
     vec_size = std::min<uint16_t>(vec_size, 4);
   }
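
For context, the vec8 path in this commit is fenced off twice: at compile time by the `#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000` guard inside the `vec_size == 8` branch, and at run time by the new compute-capability check in `launch_vectorized_kernel`. Below is a minimal, standalone sketch of the compile-time half of that pattern; the kernel name `scale_kernel` and the plain scalar loop are illustrative stand-ins and not part of the PyTorch source.

// Sketch only: a templated elementwise kernel whose 8-wide body is compiled
// solely for sm_90 / sm_100, mirroring the guard used in the diff above.
#include <cuda_runtime.h>

template <int vec_size>
__global__ void scale_kernel(int n, float a, const float* in, float* out) {
  if constexpr (vec_size == 8) {
    // Emitted only when compiling for compute capability 9.0 or 10.0; for any
    // other -gencode target this branch contributes no device code, which is
    // what keeps the per-architecture cubins (and the final .so) smaller.
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
#pragma unroll
    for (int i = 0; i < 8; ++i) {
      if (base + i < n) out[base + i] = a * in[base + i];
    }
#endif
  } else {
    // Narrower widths (vec_size = 2 or 4) are compiled for every architecture.
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
#pragma unroll
    for (int i = 0; i < vec_size; ++i) {
      if (base + i < n) out[base + i] = a * in[base + i];
    }
  }
}

On any other architecture the 8-wide instantiation compiles to an empty body, so a launch with vec_size=8 would silently do no work there; that is why the host side also clamps the requested width, as in the second sketch below.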
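A sketch of the runtime half of the check, using the plain CUDA runtime API (`cudaGetDeviceProperties`) rather than `at::cuda::getDeviceProperties`; the helper name `clamp_vec_size_for_device` and the conservative fallback on error are assumptions for this example, not part of the commit.

// Sketch only: query the device's compute capability and cap the
// vectorization width at 4 unless running on sm_90 or sm_100, matching the
// host-side logic added to launch_vectorized_kernel in this commit.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint16_t clamp_vec_size_for_device(uint16_t vec_size, int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return std::min<uint16_t>(vec_size, 4);  // assumption: be conservative on failure
  }
  const int compute_capability = prop.major * 10 + prop.minor;
  // vec8 kernels are only compiled for sm_90 / sm_100 (see the #if guard in
  // the diff), so never request a wider width on other devices.
  if (compute_capability != 90 && compute_capability != 100) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  return vec_size;
}

int main() {
  uint16_t vec_size = 8;  // e.g. what a can_vectorize_up_to-style query reported
  vec_size = clamp_vec_size_for_device(vec_size, /*device=*/0);
  std::printf("using vec_size = %u\n", vec_size);
  return 0;
}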
