Add private API to support tensor lists: _foreach_add(TensorList tens… · pytorch/pytorch@e995c3d · GitHub


Commit e995c3d

izdeby authored and facebook-github-bot committed
Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar) (#41554)
Summary: Initial PR for the Tensor List functionality. **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. **In this PR** - Adding `multi_tensor_apply` mechanism which will help to efficiently apply passed functor on a given list of tensors on CUDA. - Adding a first private API - `std::vector<Tensor> _foreach_add(TensorList tensors, Scalar scalar)` **Tests** Tested via unit tests **Plan for the next PRs** 1. Cover these ops with `multi_tensor_apply` support - exponent - division - mul_ - add_ - addcmul_ - addcdiv_ - Sqrt 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. Pull Request resolved: #41554 Reviewed By: cpuhrsch Differential Revision: D22829724 Pulled By: izdeby fbshipit-source-id: 47febdbf7845cf931958a638567b7428a24782b1
1 parent a0695b3 commit e995c3d
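For context, here is a minimal usage sketch of the new private op introduced by this commit (not part of the diff below; `_foreach_add` is private and its surface may change):

```python
import torch

# A bunch of small tensors, e.g. per-parameter optimizer state.
tensors = [torch.zeros(3, 3) for _ in range(10)]

# Baseline: one kernel launch per tensor.
looped = [t.add(1) for t in tensors]

# New private op: one call for the whole list; on CUDA it takes the fused
# multi_tensor_apply path when check_fast_route allows it.
fused = torch._foreach_add(tensors, 1)

for a, b in zip(looped, fused):
    assert torch.equal(a, b)
```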

File tree

7 files changed: +365 -0 lines changed
Lines changed: 14 additions & 0 deletions

```cpp
#include <ATen/ATen.h>
namespace at { namespace native {

std::vector<Tensor> foreach_add_scalar_kernel_fallback(TensorList tensors, Scalar scalar) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

  std::vector<Tensor> result;
  for (int i = 0; i < tensors.size(); i++) {
    auto temp = tensors[i].add(scalar);
    result.emplace_back(temp);
  }
  return result;
}
}} // namespace at::native
```
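The CPU fallback above is simply a per-tensor loop; a rough Python equivalent, for illustration only (the helper name `foreach_add_fallback` is hypothetical, not part of this commit):

```python
def foreach_add_fallback(tensors, scalar):
    # Mirrors the fallback kernel: one ordinary add per tensor.
    if len(tensors) == 0:
        raise RuntimeError("Tensor list must have at least one tensor.")
    return [t.add(scalar) for t in tensors]
```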
Lines changed: 96 additions & 0 deletions

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/ForeachUtils.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at { namespace native {

namespace {

template<typename x_t, typename out_t>
struct AddScalarFunctor {
  __device__ void operator() (
    int chunk_size,
    TensorListMetadata<2>& tl,
    x_t scalar) {
      int tensor_loc = tl.block_to_tensor[blockIdx.x];
      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      x_t* x = (x_t*)tl.addresses[0][tensor_loc];
      x += chunk_idx * chunk_size;

      out_t* out = (out_t*)tl.addresses[1][tensor_loc];
      out += chunk_idx * chunk_size;

      n -= chunk_idx * chunk_size;

      x_t r_x[kILP];
      out_t r_out[kILP];

      // to make things simple, we put aligned case in a different code path
      if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
        for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
          // load
          load_store(r_x, x, 0, i_start);
#pragma unroll
          for(int ii = 0; ii < kILP; ii++) {
            r_out[ii] = static_cast<x_t>(r_x[ii]) + scalar;
          }
          // store
          load_store(out, r_out, i_start, 0);
        }
      }
      else {
        // Non-divergent exit condition for __syncthreads, not necessary here
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
          for(int ii = 0; ii < kILP; ii++) {
            r_x[ii] = 0;
            int i = i_start + threadIdx.x + ii * blockDim.x;
            if(i < n && i < chunk_size) {
              r_x[ii] = x[i];
            }
          }
#pragma unroll
          for(int ii = 0; ii < kILP; ii++) {
            r_out[ii] = static_cast<x_t>(r_x[ii]) + scalar;
          }
#pragma unroll
          for(int ii = 0; ii < kILP; ii++) {
            int i = i_start + threadIdx.x + ii * blockDim.x;
            if(i < n && i < chunk_size)
              out[i] = r_out[ii];
          }
        }
      }
  }
};

} // namespace

std::vector<Tensor> foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Scalar scalar) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

  if (!check_fast_route(tensors, scalar)) {
    return at::native::foreach_add_scalar_kernel_fallback(tensors, scalar);
  }

  std::vector<std::vector<at::Tensor>> tensor_lists;
  std::vector<at::Tensor> vec_res;
  for (const auto& t: tensors) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  tensor_lists.emplace_back(std::move(tensors.vec()));
  tensor_lists.emplace_back(std::move(vec_res));

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda", [&]() {
    multi_tensor_apply<2>(tensor_lists, AddScalarFunctor<scalar_t, scalar_t>(), scalar.to<scalar_t>());
  });
  return tensor_lists[1];
}

}} // namespace at::native
```
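For intuition about the indexing in `AddScalarFunctor`: each CUDA block handles one chunk of one tensor, using `block_to_tensor` and `block_to_chunk` to locate its slice. A rough Python model of that per-block bookkeeping (illustration only; the helper name is hypothetical):

```python
def block_slice(sizes, block_to_tensor, block_to_chunk, block_idx, chunk_size):
    # Which tensor, and which chunk of it, this block processes.
    tensor_loc = block_to_tensor[block_idx]
    chunk_idx = block_to_chunk[block_idx]
    start = chunk_idx * chunk_size
    # 'n' in the functor: elements remaining from the start of this chunk.
    n = sizes[tensor_loc] - start
    return tensor_loc, start, min(n, chunk_size)
```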
Lines changed: 56 additions & 0 deletions

```cpp
#pragma once
#include <ATen/ATen.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
namespace at {
namespace native {
namespace {

static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  using LT = at::native::memory::aligned_vector<T, kILP>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

}

bool check_fast_route(TensorList tensors, Scalar scalar) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
  auto expected_dtype = tensors[0].dtype();
  auto expected_device = tensors[0].device();

  for (auto t : tensors) {
    if (t.dtype() != expected_dtype) {
      return false;
    }

    if (t.device() != expected_device) {
      return false;
    }

    if (t.layout() != at::kStrided) {
      return false;
    }

    if (!t.is_non_overlapping_and_dense()) {
      return false;
    }

    if ((at::isIntegralType(t.scalar_type(), true) && scalar.isFloatingPoint()) ||
        t.scalar_type() == at::kBool) {
      return false;
    }
  }

  return true;
}
}} // at::native
```
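Read together, the conditions in `check_fast_route` require all tensors to share a dtype and device, be strided, non-overlapping and dense, and the scalar must not force type promotion (no floating-point scalar on integer tensors, and no bool tensors). A rough Python restatement, for illustration only (`is_contiguous()` is used here as a conservative stand-in for the non-overlapping-and-dense check):

```python
import torch

def would_take_fast_route(tensors, scalar):
    # Illustrative mirror of check_fast_route; the real check lives in C++.
    if len(tensors) == 0:
        raise RuntimeError("Tensor list must have at least one tensor.")
    expected_dtype, expected_device = tensors[0].dtype, tensors[0].device
    for t in tensors:
        if t.dtype != expected_dtype or t.device != expected_device:
            return False
        if t.layout != torch.strided or not t.is_contiguous():
            return False
        is_integral = not (t.dtype.is_floating_point or t.dtype.is_complex)
        if (is_integral and isinstance(scalar, float)) or t.dtype == torch.bool:
            return False
    return True
```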
Lines changed: 89 additions & 0 deletions

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/ForeachUtils.cuh>
#include <c10/cuda/CUDAGuard.h>

namespace at { namespace native {

namespace {

// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template<int n> struct TensorListMetadata
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int sizes[depth_to_max_tensors[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
};

template<typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void
multi_tensor_apply_kernel(
  T tensorListMeta,
  U callable,
  ArgTypes... args) {
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(kChunkSize, tensorListMeta, args...);
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args) {
    TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
    const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));

    size_t n_tensors = tensor_lists[0].size();
    TensorListMetadata<depth> tensorListMeta;

    int loc_block_info = 0;
    int loc_tensor_info = 0;
    for(size_t t = 0; t < n_tensors; t++) {
      tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
      for (int d = 0; d < depth; d++) {
        tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
      }
      loc_tensor_info++;

      int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1)/kChunkSize;
      for (int chunk = 0; chunk < chunks; chunk++) {
        tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
        tensorListMeta.block_to_chunk[loc_block_info] = chunk;
        loc_block_info++;

        bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
            chunk == chunks - 1);
        bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
        bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);

        if (tensors_full || blocks_full || last_chunk) {
          multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta,
            callable,
            args...);

          AT_CUDA_CHECK(cudaGetLastError());

          // Reset.
          loc_block_info = 0;
          if(chunk == chunks - 1) {
            loc_tensor_info = 0;
          }
          else {
            tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
            for(int d = 0; d < depth; d++) {
              tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
            }
            loc_tensor_info = 1;
          }
        }
      }
    }
}
} // namespace
}} // at::native
```
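The host-side `multi_tensor_apply` above packs pointers and sizes into a fixed-size `TensorListMetadata` (kept under the ~4KB kernel-argument limit) and fires a kernel whenever the tensor slots or block slots fill up, or the last chunk is reached; a tensor whose chunks straddle a launch is carried into the next batch. A small Python model of just the chunk-counting and launch-batching logic (illustration only; constants copied from the headers above for depth 2):

```python
K_CHUNK_SIZE = 65536   # kChunkSize: elements handled per chunk / CUDA block
MAX_TENSORS = 64       # depth_to_max_tensors[1] for depth == 2
MAX_BLOCKS = 320       # depth_to_max_blocks[1]

def plan_launches(numels):
    """Return the number of blocks used by each kernel launch."""
    launches, blocks, tensors_in_meta = [], 0, 0
    for t, numel in enumerate(numels):
        tensors_in_meta += 1
        chunks = (numel + K_CHUNK_SIZE - 1) // K_CHUNK_SIZE
        for chunk in range(chunks):
            blocks += 1
            tensors_full = tensors_in_meta == MAX_TENSORS and chunk == chunks - 1
            blocks_full = blocks == MAX_BLOCKS
            last_chunk = t == len(numels) - 1 and chunk == chunks - 1
            if tensors_full or blocks_full or last_chunk:
                launches.append(blocks)
                blocks = 0
                # A partially processed tensor is carried into the next batch.
                tensors_in_meta = 0 if chunk == chunks - 1 else 1
    return launches

# 100 tensors of 1000 elements each -> 1 chunk per tensor,
# batched into launches of at most 64 tensors: [64, 36]
print(plan_launches([1000] * 100))
```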

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -5410,6 +5410,13 @@
     CUDA: cat_out_cuda
     QuantizedCPU: cat_out_quantized_cpu
 
+- func: _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
+  device_guard: False
+  variants: function
+  dispatch:
+    CPU: foreach_add_scalar_kernel_fallback
+    CUDA: foreach_tensor_add_scalar_kernel_cuda
+
 - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor)
   use_c10_dispatcher: full
   dispatch:
```

test/run_test.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -38,6 +38,7 @@
     'distributed/test_distributed',
     'test_distributions',
     'test_expecttest',
+    'test_foreach',
     'test_indexing',
     'test_jit',
     'test_logging',
```

test/test_foreach.py

Lines changed: 102 additions & 0 deletions

```python
import torch
import torch.cuda
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes

class TestForeach(TestCase):
    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_same_size_tensors(self, device, dtype):
        N = 20
        H = 20
        W = 20
        tensors = []
        for _ in range(N):
            tensors.append(torch.zeros(H, W, device=device, dtype=dtype))

        res = torch._foreach_add(tensors, 1)
        for t in res:
            if dtype == torch.bool:
                dtype = torch.int64
            self.assertEqual(t, torch.ones(H, W, device=device, dtype=dtype))

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_different_size_tensors(self, device, dtype):
        N = 20
        H = 20
        W = 20

        tensors = []
        size_change = 0
        for _ in range(N):
            tensors.append(torch.zeros(H + size_change, W + size_change, device=device, dtype=dtype))
            size_change += 1

        res = torch._foreach_add(tensors, 1)

        size_change = 0
        for t in res:
            if dtype == torch.bool:
                dtype = torch.int64
            self.assertEqual(t, torch.ones(H + size_change, W + size_change, device=device, dtype=dtype))
            size_change += 1

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_empty_list(self, device, dtype):
        tensors = []
        with self.assertRaises(RuntimeError):
            torch._foreach_add(tensors, 1)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_overlapping_tensors(self, device, dtype):
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
        expected = [torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=dtype, device=device)]

        if dtype == torch.bool:
            expected[0] = expected[0].to(torch.int64).add(1)

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_tensor_dtypes(self, device):
        tensors = [torch.tensor([1], dtype=torch.float, device=device),
                   torch.tensor([1], dtype=torch.int, device=device)]

        expected = [torch.tensor([2], dtype=torch.float, device=device),
                    torch.tensor([2], dtype=torch.int, device=device)]

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_scalar_type(self, device):
        # int tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([1], dtype=torch.int, device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # float tensor with int scalar
        # should go 'fast' route
        scalar = 1
        tensors = [torch.tensor([1.1], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # bool tensor with int scalar
        # should go 'slow' route
        scalar = 1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1], device=device)])

        # bool tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1.1], device=device)])

instantiate_device_type_tests(TestForeach, globals())

if __name__ == '__main__':
    run_tests()
```
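Since the test module calls `run_tests()` under `__main__` and is registered in `test/run_test.py` above, it can be run directly (e.g. `python test/test_foreach.py`) or as part of the regular test suite.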
