@@ -150,6 +150,71 @@ constexpr auto calc_io_size(){
 #endif
 }
 
+#ifndef USE_ROCM
+// To save on binary size of libtorch_cuda.so, we split vectorized_elementwise_kernel
+// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is used
+// exclusively on sm_90 and sm_100.
+template <int vec_size, typename func_t, typename array_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  if constexpr (vec_size == 8) {
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` of data to handle, use
+             // vectorized memory access
+      constexpr auto optimal_vec_size = vec_size;
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+  } else {
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` of data to handle, use
+             // vectorized memory access
+      constexpr auto optimal_vec_size = vec_size;
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+  }
+}
+
+#else // USE_ROCM
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
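The pattern above — an `if constexpr` branch on the template parameter, with the wide branch additionally fenced by `__CUDA_ARCH__` — is what keeps the vec8 instantiation out of the SASS for every architecture except sm_90/sm_100. A minimal, self-contained sketch of the same idea (toy kernel with made-up names, not PyTorch code):

```cuda
#include <cuda_runtime.h>

// demo_kernel is hypothetical; it only illustrates the compile-time split.
template <int vec_size>
__global__ void demo_kernel(const float* in, float* out, int n) {
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
  if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
    // Wide path: SASS for this branch is emitted only when compiling for
    // sm_90 / sm_100, so fatbins for other architectures stay smaller.
#pragma unroll
    for (int k = 0; k < 8; ++k)
      if (i + k < n) out[i + k] = in[i + k] * 2.0f;
#endif
  } else {
    // Narrow path (vec_size = 2 or 4): compiled for every target.
#pragma unroll
    for (int k = 0; k < vec_size; ++k)
      if (i + k < n) out[i + k] = in[i + k] * 2.0f;
  }
}
```

On targets where the guard is false, the `vec_size == 8` instantiation compiles to an empty body; that is safe here because the launcher never selects vec8 on those devices (see the capability check in the last hunk).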
@@ -174,15 +239,10 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` of data to handle, use
            // vectorized memory access
-#ifdef USE_ROCM
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
-#else
-    constexpr auto optimal_vec_size = vec_size;
-#endif
-    elementwise_kernel_helper(
-        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }
+#endif // USE_ROCM
 
 template <
     typename func_t,
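On the host side, the launcher has to turn the runtime-chosen `vec_size` into one of these compile-time instantiations. Below is a schematic of that dispatch, assuming `vectorized_elementwise_kernel` and `num_threads()` from the file above are in scope — a simplified sketch with a hypothetical helper name, not the actual `launch_vectorized_kernel` body:

```cuda
// Hypothetical helper: maps a runtime vec_size onto a template instantiation.
template <typename func_t, typename array_t>
void dispatch_by_vec_size(
    int vec_size, int64_t grid, cudaStream_t stream, int N, func_t f, array_t data) {
  switch (vec_size) {
    case 8: // only ever chosen on sm_90 / sm_100; see the capability check below
      vectorized_elementwise_kernel<8, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      break;
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      break;
    case 2:
      vectorized_elementwise_kernel<2, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      break;
    default:
      break; // vec_size == 1 falls back to a non-vectorized path
  }
}
```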
@@ -229,6 +289,11 @@ static inline void launch_vectorized_kernel(
   // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
   // that causes some numerical mismatches with uint8 on sm80 and sm90.
   // TODO: Revisit this after CUDA 12.8 update.
+  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
+  const int computeCapability = p->major * 10 + p->minor;
+  if (computeCapability != 90 && computeCapability != 100) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
   if constexpr (sizeof(cpp_type) < 2) {
     vec_size = std::min<uint16_t>(vec_size, 4);
   }
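Outside of PyTorch, the same clamp can be written against the plain CUDA runtime API, which `at::cuda::getDeviceProperties` wraps with caching. An illustrative, self-contained sketch (hypothetical helper name):

```cuda
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Clamp the requested vector size to 4 unless the device is sm_90 or sm_100,
// the only architectures for which the vec8 kernel body is compiled.
uint16_t clamp_vec_size_for_device(uint16_t vec_size, int device) {
  cudaDeviceProp p{};
  cudaGetDeviceProperties(&p, device);
  const int cc = p.major * 10 + p.minor; // e.g. 90 for sm_90, 100 for sm_100
  if (cc != 90 && cc != 100) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  return vec_size;
}
```

Together with the `__CUDA_ARCH__` guard inside the kernel, this keeps host and device consistent: vec8 is neither compiled for nor requested on other architectures.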