Update · pytorch/pytorch@f3d6c5a

Commit f3d6c5a

Update

[ghstack-poisoned]

2 parents 334c9b0 + 0dac3f6 · commit f3d6c5a


45 files changed · +985 −657 lines changed

aten/src/ATen/TensorUtils.cpp

Lines changed: 2 additions & 2 deletions
@@ -372,14 +372,14 @@ inline std::optional<ResultVec> computeStride_impl(
     // if end of tensor size chunk, check view
     if ((tensor_d == 0) ||
         (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
-         oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
+         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
       while (view_d >= 0 &&
              (TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
         newstride[view_d] = view_numel * chunk_base_stride;
         view_numel *= newshape[view_d];
         view_d--;
       }
-      if (view_numel != tensor_numel) {
+      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(view_numel, tensor_numel))) {
         return std::nullopt;
       }
       if (tensor_d > 0) {
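
For context on this hunk: TORCH_GUARD_SIZE_OBLIVIOUS evaluates a symbolic comparison while treating unbacked, size-like symbols (for example sizes produced by data-dependent ops such as nonzero) as ordinary sizes, so computeStride_impl can decide whether a view is representable without guarding on a concrete value. Below is a minimal sketch of the two guard styles; it only reuses the sym_ne comparison and the TORCH_GUARD_SIZE_OBLIVIOUS macro that appear in the hunk, and the header paths are an assumption on our part, not taken from this commit.

// Hedged sketch: with backed (concrete) SymInts both forms simply evaluate the
// comparison; the size-oblivious form only behaves differently once the values
// are unbacked symbols created under torch.compile/export.
#include <c10/core/SymBool.h>  // TORCH_GUARD_SIZE_OBLIVIOUS (assumed header)
#include <c10/core/SymInt.h>   // c10::SymInt (assumed header)
#include <iostream>

int main() {
  c10::SymInt old_stride = 12, tensor_numel = 12;
  // Old guard style: a direct comparison, which must know the concrete values.
  bool direct = old_stride != tensor_numel;
  // New guard style: size-oblivious comparison (member-function spelling of the
  // sym_ne used above); for unbacked sizes this avoids raising a data-dependent
  // guard instead of forcing the value to be known.
  bool oblivious = TORCH_GUARD_SIZE_OBLIVIOUS(old_stride.sym_ne(tensor_numel));
  std::cout << direct << " " << oblivious << "\n";  // prints "0 0" here
  return 0;
}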

aten/src/ATen/cuda/cub.cuh

Lines changed: 41 additions & 15 deletions
@@ -349,7 +349,7 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
   // Per-thread tile data
   T data[ITEMS_PER_THREAD];
 
-  int remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
   for (int i=0; i<iters_per_cta; i++){
     // Load items into a blocked arrangement
     if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
@@ -386,38 +386,57 @@ __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem
 
 }
 
+template <typename T, typename aggT, bool nonzero>
+struct TransformFunctor {
+  __device__ aggT operator()(T value) const {
+    if constexpr (!nonzero) {
+      return value;
+    } else {
+      return (value != T(0)) ? 1 : 0;
+    }
+  }
+};
 
-
-template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
-__global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iters_per_cta){
+template<int BLOCK_THREADS, int ITEMS_PER_THREAD, bool nonzero, typename T, typename aggT>
+__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){
   if (BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x >= nelem) return;
-  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  d_in += BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
 
-  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
-  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;
+  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<aggT, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
+  using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<aggT, BLOCK_THREADS>;
   // Shared memory
   __shared__ union TempStorage
   {
     typename BlockLoadT::TempStorage load;
     typename BlockReduceT::TempStorage reduce;
   } temp_storage;
-  T data[ITEMS_PER_THREAD];
-  T agg_val = 0;
-  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * blockIdx.x;
+  aggT data[ITEMS_PER_THREAD];
+  aggT agg_val = 0;
+  int64_t remaining = nelem - BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
+  TransformFunctor<T, aggT, nonzero> transform_functor;
+  auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<aggT, TransformFunctor<T, aggT, nonzero>, const T*>(d_in, transform_functor);
   for (int i=0; i<iters_per_cta; i++){
     if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
-      BlockLoadT(temp_storage.load).Load(d_in, data);
+      BlockLoadT(temp_storage.load).Load(iter_in, data);
       __syncthreads();
       agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
 
     } else {
-      BlockLoadT(temp_storage.load).Load(d_in, data, remaining);
+      BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0));
       __syncthreads();
       agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
     }
-    d_in += BLOCK_THREADS * ITEMS_PER_THREAD;
+    iter_in += BLOCK_THREADS * ITEMS_PER_THREAD;
     remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
-    if (remaining <= 0) return;
+    if (remaining <= 0) {
+      // for nonzeros we need to write out last blocks
+      // accumulated value to be able to compute
+      // total number of nonzeros
+      if (nonzero && threadIdx.x == 0) {
+        agg[blockIdx.x] = agg_val;
+      }
+      return;
+    }
    __syncthreads();
 
   }
@@ -427,6 +446,13 @@ __global__ void calc_block_sums(const T * d_in, T * agg, int64_t nelem, int iter
 
 }
 
+template <typename T>
+struct NonZeroOp {
+  __host__ __device__ __forceinline__ int operator()(const T& a) const {
+    return (a != T(0));
+  }
+};
+
 template<int size>
 constexpr int block_threads(){
   if constexpr (size >=16) {
@@ -450,7 +476,7 @@ inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * out
   grid_size = std::min(num_sms, grid_size);
   auto& allocator = *c10::cuda::CUDACachingAllocator::get();
   auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
-  calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD>
+  calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, false>
     <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
       input, (scalar_t*)agg.get(), num_items, iters_per_cta);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
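
For context on the calc_block_sums changes: the kernel now takes a nonzero template flag and a separate accumulator type aggT, and routes its loads through cub's TransformInputIterator so each element can be mapped to 0/1 on the fly when counting nonzeros; the existing deterministic-scan caller simply instantiates it with false and keeps the old behavior. The same fused-transform idea can be sketched with the stock device-wide reduction. The sketch below is illustrative only, not part of this commit: the float element type and the count_nonzero helper are assumptions, and it uses cub::DeviceReduce instead of the custom block kernel.

// Hedged sketch: count nonzeros by fusing a 0/1 transform into the load path
// with cub::TransformInputIterator, the same trick the templated
// calc_block_sums now applies per block, expressed here with cub::DeviceReduce.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

struct IsNonZero {
  __host__ __device__ int operator()(float x) const { return x != 0.0f ? 1 : 0; }
};

int count_nonzero(const float* d_in, int num_items, cudaStream_t stream = 0) {
  cub::TransformInputIterator<int, IsNonZero, const float*> it(d_in, IsNonZero{});

  int* d_count = nullptr;
  cudaMalloc(&d_count, sizeof(int));

  // First call queries the temporary-storage size, second call runs the reduction.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, it, d_count, num_items, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, it, d_count, num_items, stream);

  int h_count = 0;
  cudaMemcpyAsync(&h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_temp);
  cudaFree(d_count);
  return h_count;
}

The commit keeps its own block-level kernel rather than a device-wide reduction so the per-block sums can be reused as prefix-sum partials by the new Nonzero.cu pipeline below.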

aten/src/ATen/native/cuda/Nonzero.cu

Lines changed: 243 additions & 3 deletions
@@ -37,9 +37,12 @@ __global__ void write_indices(
     int64_t* inp,
     TensorDims<index_t> dims,
     int ndim,
-    index_t n) {
-  auto index = threadIdx.x + blockIdx.x * blockDim.x;
-  if (index < n) {
+    index_t n,
+    int64_t * total = nullptr,
+    int64_t fill_value = -1) {
+  auto index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
+  bool cond = (total == nullptr || index < *total);
+  if (index < n && cond) {
     index_t div = 1;
     int64_t idx_flat = inp[index];
 #pragma unroll
@@ -50,9 +53,117 @@
       inp[index + dim * n] = (idx_flat / div) % dim_size;
       div *= dim_size;
     }
+  } else if (index < n) {
+    // 0th dim has correct values already
+    for (int dim = ndim - 1; dim > 0; dim--) {
+      inp[index + dim * n] = fill_value;
+    }
+  }
+}
+
+__global__ void write_fill_value(int64_t * inp, int64_t * total, int64_t fill_value, int64_t n){
+  int64_t total_val = *total;
+  // not aiming for vectorized stores
+
+  for (int64_t idx = total_val + (int64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += blockDim.x * gridDim.x) {
+    inp[idx] = fill_value;
   }
 }
 
+template <int BLOCK_THREADS>
+__global__ void compute_agg(int32_t * agg, int64_t * agg_cum, uint32_t n_blocks) {
+
+  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int64_t, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
+  __shared__ typename BlockScanT::TempStorage temp_storage;
+  int agg_data;
+  int64_t agg_cum_data;
+  agg_data = threadIdx.x < n_blocks ? agg[threadIdx.x] : 0;
+  BlockScanT(temp_storage).InclusiveSum(agg_data, agg_cum_data);
+  if (threadIdx.x < n_blocks) {
+    agg_cum[threadIdx.x] = agg_cum_data;
+  }
+}
+
+template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
+__global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg, int64_t input_nelem, int64_t output_nelem, int iters_per_cta) {
+  int64_t start_idx = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x;
+  if (start_idx >= input_nelem) return;
+  d_in += start_idx;
+
+  using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
+
+  // Specialize BlockScan type for our thread block
+  using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
+  using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<int, NonZeroOp<T>, const T*>;
+  using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
+
+  // Shared memory
+  __shared__ union TempStorage
+  {
+    typename BlockLoadT::TempStorage load;
+    typename BlockScanT::TempStorage scan;
+    typename BlockExchangeT::TempStorage exchange;
+  } temp_storage;
+
+  int64_t aggregate = blockIdx.x == 0 ? 0 : agg[blockIdx.x - 1];
+  d_out += aggregate;
+
+  TransformInputIteratorT t_input_itr(d_in, NonZeroOp<T>());
+
+  // Per-thread tile data
+  int data[ITEMS_PER_THREAD];
+  int out_indices[ITEMS_PER_THREAD];
+
+  int64_t remaining = input_nelem - start_idx;
+  int64_t out_remaining = output_nelem - aggregate;
+  for (int i=0; i<iters_per_cta; i++){
+
+    // Load items into a blocked arrangement
+    if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
+      BlockLoadT(temp_storage.load).Load(t_input_itr, data);
+    } else {
+      BlockLoadT(temp_storage.load).Load(t_input_itr, data, remaining, int(0));
+    }
+
+    // Barrier for smem reuse
+    __syncthreads();
+
+    // Compute inclusive prefix sum
+    int aggregate;
+    __shared__ int aggregate_sh;
+    BlockScanT(temp_storage.scan).ExclusiveSum(data, out_indices, aggregate);
+
+    if (threadIdx.x == 0){
+      aggregate_sh = aggregate;
+    }
+
+    // Barrier for smem reuse
+    __syncthreads();
+    // striped arrangement will provide a slightly better
+    // coalescing for writes (although it's still bad because it's indirect indexing)
+    BlockExchangeT(temp_storage.exchange).BlockedToStriped(data);
+    __syncthreads();
+    BlockExchangeT(temp_storage.exchange).BlockedToStriped(out_indices);
+    for (int ii=0; ii<ITEMS_PER_THREAD; ii++){
+      if (data[ii] != 0 && out_indices[ii] < out_remaining) {
+        int64_t inp_idx = start_idx + threadIdx.x + blockDim.x * ii;
+        d_out[out_indices[ii]] = inp_idx;
+      }
+    }
+
+    out_remaining -= aggregate_sh;
+    remaining -= BLOCK_THREADS * ITEMS_PER_THREAD;
+    if (remaining <= 0 || out_remaining <= 0) return;
+    d_out += aggregate_sh;
+    t_input_itr += BLOCK_THREADS * ITEMS_PER_THREAD;
+    start_idx += BLOCK_THREADS * ITEMS_PER_THREAD;
+    __syncthreads();
+  }
+
+}
+
+
+
 } // anonymous namespace
 
 template <typename scalar_t>
@@ -183,6 +294,83 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   }
 }
 
+template <typename scalar_t>
+void nonzero_static_cuda_out_impl(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
+
+  Tensor self_contiguous_ = self.contiguous();
+  // see comment in nonzero_cuda_out_impl on reqs for out
+  bool out_correct_size =
+      out.dim() == 2 && out.sizes()[0] == size && out.sizes()[1] == self.dim();
+  bool need_to_copy = out_correct_size && !out.t().is_contiguous();
+  if (!out_correct_size) {
+    out.resize_({self.dim(), size}).t();
+  }
+  if (out.numel() == 0) return;
+  // we need to allocate temporary out to then copy to user provided out
+  at::Tensor out_temp;
+  if (need_to_copy) {
+    out_temp =
+        Tensor(at::detail::empty_cuda({self.dim(), size}, out.options())).t();
+  }
+  int64_t* out_data_ptr = need_to_copy ? out_temp.mutable_data_ptr<int64_t>()
+                                       : out.mutable_data_ptr<int64_t>();
+
+  const scalar_t * in_data_ptr = self_contiguous_.const_data_ptr<scalar_t>();
+  constexpr int BLOCK_THREADS = 512; //block_threads<sizeof(scalar_t)>();
+  constexpr int ITEMS_PER_THREAD = 16;
+  auto grid_size = (self.numel() + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD);
+  const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  int64_t target_blocks = sizeof(scalar_t) == 1 ? 2 * num_sms : num_sms;
+  const int iters_per_cta = (grid_size + target_blocks - 1)/target_blocks;
+  grid_size = (self.numel() + iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (iters_per_cta * BLOCK_THREADS * ITEMS_PER_THREAD);
+  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+  auto agg = allocator.allocate(grid_size * sizeof(int));
+  at::cuda::cub::calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, true>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          in_data_ptr, (int*)agg.get(), self.numel(), iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  auto agg_cum = allocator.allocate(grid_size * sizeof(int64_t));
+  // computing partial sums in int64 in the flag kernel
+  // leads to 20-30% slowdown, so compute them in a separate 2 us kernel
+  compute_agg<BLOCK_THREADS><<<1, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+      (int*)agg.get(), (int64_t*)agg_cum.get(), grid_size
+  );
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  flag_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
+      <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
+  write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size);
+  if (self.dim() > 1) {
+    TensorDims<int64_t> dims;
+    for (int i = 0; i < self.dim(); i++) {
+      dims.sizes[i] = self.sizes()[i];
+    }
+    const int nthreads = 256;
+    const int nblocks = (size + nthreads - 1) / nthreads;
+    write_indices<<<nblocks, nthreads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        out_data_ptr,
+        dims,
+        self.dim(),
+        size,
+        (int64_t *)agg_cum.get() + grid_size - 1,
+        fill_value);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+  if (need_to_copy) {
+    out.copy_(out_temp);
+  }
+#else
+  TORCH_CHECK(false, "Nonzero_static is not supported for cuda <= 11.4");
+#endif
+}
+
 Tensor& nonzero_out_cuda(const Tensor& self, Tensor& out) {
   TORCH_CHECK(
       out.dtype() == at::kLong,
@@ -216,4 +404,56 @@ Tensor nonzero_cuda(const Tensor& self) {
   Tensor out = at::detail::empty_cuda({0}, self.options().dtype(kLong));
   return at::native::nonzero_out_cuda(self, out);
 }
+
+Tensor& nonzero_static_out_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value,
+    Tensor& out) {
+  TORCH_CHECK(
+      out.dtype() == at::kLong,
+      "nonzero_static: Expected out tensor to have scalar type ",
+      at::kLong,
+      " but got ",
+      out.dtype());
+  TORCH_CHECK(
+      self.device() == out.device(),
+      "expected self and out to be on the same device, but got out on ",
+      out.device(),
+      " and self on ",
+      self.device());
+  TORCH_CHECK(
+      self.dim() <= MAX_DIMS,
+      "nonzero_static is not supported for tensor with more than ",
+      MAX_DIMS,
+      " dimensions");
+  TORCH_CHECK(
+      size >= 0, "nonzero_static: 'size' must be an non-negative integer"
+  )
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      at::ScalarType::ComplexHalf,
+      at::ScalarType::Bool,
+      at::ScalarType::BFloat16,
+      at::ScalarType::Half,
+      self.scalar_type(),
+      "nonzero_cuda",
+      [&] {
+        nonzero_static_cuda_out_impl<scalar_t>(self, size, fill_value, out);
+      });
+  return out;
+}
+
+Tensor nonzero_static_cuda(
+    const Tensor& self,
+    int64_t size,
+    int64_t fill_value) {
+  TORCH_CHECK(
+      size >= 0, "nonzero_static: 'size' must be an non-negative integer"
+  )
+  Tensor out = Tensor(at::detail::empty_cuda(
+                          {self.dim(), size}, self.options().dtype(kLong)))
+                   .t();
+  return at::native::nonzero_static_out_cuda(self, size, fill_value, out);
+}
+
 } // namespace at::native
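
For context on the new entry points: nonzero_static_out_cuda and nonzero_static_cuda produce a fixed-shape (size, self.dim()) index tensor, writing the first nonzero coordinates in row-major order, padding any remaining rows with fill_value, and dropping nonzeros beyond size, which is the usual motivation for the static variant of nonzero. A hedged C++ usage sketch through the generated ATen API follows; the at::nonzero_static(self, size, fill_value) signature is inferred from the nonzero_static(Tensor self, *, int size, int fill_value=-1) schema and is an assumption here, not shown in this diff.

// Hedged usage sketch (assumes a CUDA-enabled libtorch build and that
// at::nonzero_static is the generated C++ binding for the schema above).
#include <ATen/ATen.h>

int main() {
  // Values 0, 1, 2, 0, 1 -> nonzeros at flat positions 1, 2, 4.
  at::Tensor x = at::remainder(
      at::arange(5, at::dtype(at::kLong).device(at::kCUDA)), 3);

  // Ask for exactly 4 index rows: the three real nonzero positions followed by
  // one row of -1 padding written by write_fill_value.
  at::Tensor idx = at::nonzero_static(x, /*size=*/4, /*fill_value=*/-1);
  TORCH_CHECK(idx.sizes() == at::IntArrayRef({4, 1}));

  // With size smaller than the true count, trailing nonzeros are dropped.
  at::Tensor idx2 = at::nonzero_static(x, /*size=*/2);
  TORCH_CHECK(idx2.sizes() == at::IntArrayRef({2, 1}));
  return 0;
}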
