@@ -37,9 +37,12 @@ __global__ void write_indices(
     int64_t* inp,
     TensorDims<index_t> dims,
     int ndim,
-    index_t n) {
-  auto index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < n) {
+    index_t n,
+    int64_t* total = nullptr,
+    int64_t fill_value = -1) {
+  auto index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
+  bool cond = (total == nullptr || index < *total);
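+  // `total`, when non-null, is a device pointer to the true nonzero count:
+  // rows below *total hold flat indices to unravel, while rows at or above it
+  // were already filled with fill_value by write_fill_value and only need the
+  // remaining dimensions padded in the else-branch below.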
+  if (index < n && cond) {
     index_t div = 1;
     int64_t idx_flat = inp[index];
 #pragma unroll
@@ -50,9 +53,117 @@ __global__ void write_indices(
       inp[index + dim * n] = (idx_flat / div) % dim_size;
       div *= dim_size;
     }
+  } else if (index < n) {
+    // 0th dim has correct values already
+    for (int dim = ndim - 1; dim > 0; dim--) {
+      inp[index + dim * n] = fill_value;
+    }
+  }
+}
+
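+// Pads the flat index buffer from the device-side nonzero count (*total) up to
+// n with fill_value, using a grid-stride loop so a small, SM-count-sized grid
+// covers any output length.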
+__global__ void write_fill_value(int64_t* inp, int64_t* total, int64_t fill_value, int64_t n){
+  int64_t total_val = *total;
+  // not aiming for vectorized stores
+
+  for (int64_t idx = total_val + (int64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += blockDim.x * gridDim.x) {
+    inp[idx] = fill_value;
   }
 }
 
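+// Single-block kernel: converts the per-CTA int32 nonzero counts produced by
+// calc_block_sums into an int64 inclusive prefix sum, so agg_cum[n_blocks - 1]
+// is the total nonzero count. Assumes n_blocks <= BLOCK_THREADS.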
+template <int BLOCK_THREADS>
+__global__ void compute_agg(int32_t* agg, int64_t* agg_cum, uint32_t n_blocks) {
+
+  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int64_t, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
+  __shared__ typename BlockScanT::TempStorage temp_storage;
+  int agg_data;
+  int64_t agg_cum_data;
+  agg_data = threadIdx.x < n_blocks ? agg[threadIdx.x] : 0;
+  BlockScanT(temp_storage).InclusiveSum(agg_data, agg_cum_data);
+  if (threadIdx.x < n_blocks) {
+    agg_cum[threadIdx.x] = agg_cum_data;
+  }
+}
+
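+// Each CTA processes iters_per_cta tiles of BLOCK_THREADS * ITEMS_PER_THREAD
+// elements: it loads a tile, block-scans the nonzero flags into local output
+// slots, and writes the flat input index of every nonzero element to d_out,
+// offset by the running total of all preceding CTAs (agg).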
+template <int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
+__global__ void flag_kernel(const T* d_in, int64_t* d_out, const int64_t* agg, int64_t input_nelem, int64_t output_nelem, int iters_per_cta) {
+  int64_t start_idx = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
+  if (start_idx >= input_nelem) return;
+  d_in += start_idx;
+
+  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
+
+  // Specialize BlockScan type for our thread block
+  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
+  using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<int, NonZeroOp<T>, const T*>;
+  using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
+
+  // Shared memory
+  __shared__ union TempStorage
+  {
+    typename BlockLoadT::TempStorage load;
+    typename BlockScanT::TempStorage scan;
+    typename BlockExchangeT::TempStorage exchange;
+  } temp_storage;
+
+  int64_t aggregate = blockIdx.x == 0 ? 0 : agg[blockIdx.x - 1];
+  d_out += aggregate;
+
+  TransformInputIteratorT t_input_itr(d_in, NonZeroOp<T>());
+
+  // Per-thread tile data
+  int data[ITEMS_PER_THREAD];
+  int out_indices[ITEMS_PER_THREAD];
+
+  int64_t remaining = input_nelem - start_idx;
+  int64_t out_remaining = output_nelem - aggregate;
+  for (int i=0; i<iters_per_cta; i++){
+
+    // Load items into a blocked arrangement
+    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
+      BlockLoadT(temp_storage.load).Load(t_input_itr, data);
+    } else {
+      BlockLoadT(temp_storage.load).Load(t_input_itr, data, remaining, int(0));
+    }
+
+    // Barrier for smem reuse
+    __syncthreads();
+
+    // Compute exclusive prefix sum
+    int aggregate;
+    __shared__ int aggregate_sh;
+    BlockScanT(temp_storage.scan).ExclusiveSum(data, out_indices, aggregate);
+
+    if (threadIdx.x == 0){
+      aggregate_sh = aggregate;
+    }
+
+    // Barrier for smem reuse
+    __syncthreads();
+    // striped arrangement will provide a slightly better
+    // coalescing for writes (although it's still bad because it's indirect indexing)
+    BlockExchangeT(temp_storage.exchange).BlockedToStriped(data);
+    __syncthreads();
+    BlockExchangeT(temp_storage.exchange).BlockedToStriped(out_indices);
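+    // Writes are bounded by out_remaining, i.e. what is left of the
+    // user-provided output size; once the output is full, further nonzeros
+    // are dropped.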
+    for (int ii=0; ii<ITEMS_PER_THREAD; ii++){
+      if (data[ii] != 0 && out_indices[ii] < out_remaining) {
+        int64_t inp_idx = start_idx + threadIdx.x + blockDim.x * ii;
+        d_out[out_indices[ii]] = inp_idx;
+      }
+    }
+
+    out_remaining -= aggregate_sh;
+    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
+    if (remaining <= 0 || out_remaining <= 0) return;
+    d_out += aggregate_sh;
+    t_input_itr += BLOCK_THREADS * ITEMS_PER_THREAD;
+    start_idx += BLOCK_THREADS * ITEMS_PER_THREAD;
+    __syncthreads();
+  }
+
+}
+
+
+
 } // anonymous namespace
 
 template <typename scalar_t>
@@ -183,6 +294,83 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   }
 }
 
+template <typename scalar_t>
+void nonzero_static_cuda_out_impl(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+#if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
+
+  Tensor self_contiguous_ = self.contiguous();
+  // see comment in nonzero_cuda_out_impl on reqs for out
+  bool out_correct_size =
+      out.dim() == 2 && out.sizes()[0] == size && out.sizes()[1] == self.dim();
+  bool need_to_copy = out_correct_size && !out.t().is_contiguous();
+  if (!out_correct_size) {
+    out.resize_({self.dim(), size}).t();
+  }
+  if (out.numel() == 0) return;
+  // we need to allocate temporary out to then copy to user provided out
+  at::Tensor out_temp;
+  if (need_to_copy) {
+    out_temp =
+        Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t();
+  }
+  int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr<int64_t>()
+                                       : out.mutable_data_ptr<int64_t>();
+
+  const scalar_t* in_data_ptr = self_contiguous_.const_data_ptr<scalar_t>();
+  constexpr int BLOCK_THREADS = 512; // block_threads<sizeof(scalar_t)>();
+  constexpr int ITEMS_PER_THREAD = 16;
+  auto grid_size = (self.numel() + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
+  const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  int64_t target_blocks = sizeof(scalar_t) == 1 ? 2 * num_sms : num_sms;
+  const int iters_per_cta = (grid_size + target_blocks - 1)/target_blocks;
+  grid_size = (self.numel() + iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD);
+  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+  auto agg = allocator.allocate(grid_size * sizeof(int));
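+  // Pipeline: calc_block_sums counts nonzeros per CTA tile, compute_agg
+  // prefix-sums those counts into int64 offsets, flag_kernel writes the flat
+  // indices of the first `size` nonzeros, and write_fill_value / write_indices
+  // pad the tail with fill_value and unravel flat indices into per-dim coords.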
+  at::cuda::cub::calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, true>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          in_data_ptr, (int*)agg.get(), self.numel(), iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  auto agg_cum = allocator.allocate(grid_size * sizeof(int64_t));
+  // computing partial sums in int64 in the flag kernel
+  // leads to 20-30% slowdown, so compute them in a separate 2 us kernel
+  compute_agg<BLOCK_THREADS><<<1, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+      (int*)agg.get(), (int64_t*)agg_cum.get(), grid_size
+  );
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  flag_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
+  write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t*)agg_cum.get() + grid_size - 1, fill_value, size);
+  if (self.dim() > 1) {
+    TensorDims<int64_t> dims;
+    for (int i = 0; i < self.dim(); i++) {
+      dims.sizes[i] = self.sizes()[i];
+    }
+    const int nthreads = 256;
+    const int nblocks = (size + nthreads - 1) / nthreads;
+    write_indices<<<nblocks, nthreads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        out_data_ptr,
+        dims,
+        self.dim(),
+        size,
+        (int64_t*)agg_cum.get() + grid_size - 1,
+        fill_value);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+  if (need_to_copy) {
+    out.copy_(out_temp);
+  }
+#else
+  TORCH_CHECK(false, "Nonzero_static is not supported for cuda <= 11.4");
+#endif
+}
+
 Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
   TORCH_CHECK(
       out.dtype() == at::kLong,
@@ -216,4 +404,56 @@ Tensor nonzero_cuda(const Tensor& self) {
   Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
   return at::native::nonzero_out_cuda(self, out);
 }
+
+Tensor& nonzero_static_out_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+  TORCH_CHECK(
+      out.dtype() == at::kLong,
+      "nonzero_static: Expected out tensor to have scalar type ",
+      at::kLong,
+      " but got ",
+      out.dtype());
+  TORCH_CHECK(
+      self.device() == out.device(),
+      "expected self and out to be on the same device, but got out on ",
+      out.device(),
+      " and self on ",
+      self.device());
+  TORCH_CHECK(
+      self.dim() <= MAX_DIMS,
+      "nonzero_static is not supported for tensor with more than ",
+      MAX_DIMS,
+      " dimensions");
+  TORCH_CHECK(
+      size >= 0, "nonzero_static: 'size' must be a non-negative integer"
+  )
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      at::ScalarType::ComplexHalf,
+      at::ScalarType::Bool,
+      at::ScalarType::BFloat16,
+      at::ScalarType::Half,
+      self.scalar_type(),
+      "nonzero_cuda",
+      [&] {
+        nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
+      });
+  return out;
+}
+
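+// The result is allocated as a {self.dim(), size} buffer and returned as its
+// transpose, so each dimension's column of indices is contiguous, matching the
+// inp[index + dim * n] layout written by the kernels above.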
+Tensor nonzero_static_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value) {
+  TORCH_CHECK(
+      size >= 0, "nonzero_static: 'size' must be a non-negative integer"
+  )
+  Tensor out = Tensor(at::detail::empty_cuda(
+                          {self.dim(), size}, self.options().dtype(kLong)))
+                   .t();
+  return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
+}
+
 } // namespace at::native