Update on "[IGNORE] Added add_scalar_" · pytorch/pytorch@cf29686 · GitHub

Commit cf29686

Authored and committed by Iurii Zdebskyi

Update on "[IGNORE] Added add_scalar_"
**Motivation**

[GitHub issue](#38655). Current PyTorch optimizer implementations are inefficient when we work with many small feature tensors: every per-tensor operation launches its own kernel, and starting that many kernels slows down the whole step, so we need to reduce the number of kernel launches. [NVIDIA's Apex](https://github.com/NVIDIA/apex) is the reference point. To track progress, we will take PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark its performance against the original model, Apex's version with the original Adam optimizer, and its FusedAdam optimizer.

**In this PR**

- Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API

**Tests**

Tested via unit tests.

**Plan for the next PRs**

1. APIs:
   - Binary ops for a list with a scalar
   - Binary ops for a list with a list
   - Unary ops for a list
2. Rewrite PyTorch optimizers to use the for-each operators in order to get performance gains.

[ghstack-poisoned]
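For context, a minimal sketch (not part of this commit; list size and shapes are illustrative) of how the for-each API batches what would otherwise be one kernel launch per tensor:

```python
import torch

params = [torch.zeros(20, 20) for _ in range(20)]

# Naive approach: a Python loop dispatches one add kernel per tensor.
out_loop = [p + 1 for p in params]

# For-each approach: the whole list is processed with far fewer launches.
out_foreach = torch._foreach_add(params, 1)

# In-place variant, as exercised by the new test in the diff below.
torch._foreach_add_(params, 1)
```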
1 parent (8e6256a) · commit cf29686

File tree

2 files changed: +23, -42 lines

aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu (0 additions, 3 deletions)

@@ -2,9 +2,6 @@
 #include <ATen/native/cuda/ForeachUtils.cuh>
 #include <ATen/native/cuda/MultiTensorApply.cuh>
 
-// NOTE: CUDA on Windows requires that the enclosing function
-// of a __device__ lambda not have internal linkage.
-
 namespace at { namespace native {
 
 namespace {

test/test_foreach.py (23 additions, 39 deletions)

@@ -8,9 +8,7 @@ def test_add_scalar__same_size_tensors(self, device, dtype):
         N = 20
         H = 20
         W = 20
-        tensors = []
-        for _ in range(N):
-            tensors.append(torch.zeros(H, W, device=device, dtype=dtype))
+        tensors = [torch.zeros(H, W, device=device, dtype=dtype) for n in range(N)]
 
         # bool tensor + 1 will result in int64 tensor
         if dtype == torch.bool:
@@ -26,9 +24,7 @@ def test_add_scalar_with_same_size_tensors(self, device, dtype):
         N = 20
         H = 20
         W = 20
-        tensors = []
-        for _ in range(N):
-            tensors.append(torch.zeros(H, W, device=device, dtype=dtype))
+        tensors = [torch.zeros(H, W, device=device, dtype=dtype) for n in range(N)]
 
         res = torch._foreach_add(tensors, 1)
         for t in res:
@@ -43,27 +39,23 @@ def test_add_scalar_with_different_size_tensors(self, device, dtype):
         H = 20
         W = 20
 
-        tensors = []
-        size_change = 0
-        for _ in range(N):
-            tensors.append(torch.zeros(H + size_change, W + size_change, device=device, dtype=dtype))
-            size_change += 1
-
+        tensors = [torch.zeros(H + n, W + n, device=device, dtype=dtype) for n in range(N)]
         res = torch._foreach_add(tensors, 1)
 
-        size_change = 0
-        for t in res:
-            # bool tensor + 1 will result in int64 tensor
-            if dtype == torch.bool:
-                dtype = torch.int64
-            self.assertEqual(t, torch.ones(H + size_change, W + size_change, device=device, dtype=dtype))
-            size_change += 1
+        # bool tensor + 1 will result in int64 tensor
+        if dtype == torch.bool:
+            dtype = torch.int64
+        self.assertEqual([torch.ones(H + n, W + n, device=device, dtype=dtype) for n in range(N)], torch._foreach_add(tensors, 1))
 
     @dtypes(*torch.testing.get_all_dtypes())
-    def test_add_scalar_with_empty_list(self, device, dtype):
-        tensors = []
-        with self.assertRaises(RuntimeError):
-            torch._foreach_add(tensors, 1)
+    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
+        # TODO: enable empty list case
+        for tensors in [[torch.randn([0])]]:
+            res = torch._foreach_add(tensors, 1)
+            self.assertEqual(res, tensors)
+
+            torch._foreach_add_(tensors, 1)
+            self.assertEqual(res, tensors)
 
     @dtypes(*torch.testing.get_all_dtypes())
     def test_add_scalar_with_overlapping_tensors(self, device, dtype):
@@ -78,43 +70,35 @@ def test_add_scalar_with_overlapping_tensors(self, device, dtype):
         self.assertEqual(res, expected)
 
     def test_add_scalar_with_different_tensor_dtypes(self, device):
-        tensors = [torch.tensor([1], dtype=torch.float, device=device),
-                   torch.tensor([1], dtype=torch.int, device=device)]
+        tensors = [torch.tensor([1.1], dtype=torch.float, device=device),
+                   torch.tensor([1], dtype=torch.long, device=device)]
 
-        expected = [torch.tensor([2], dtype=torch.float, device=device),
-                    torch.tensor([2], dtype=torch.int, device=device)]
+        expected = [torch.tensor([2.1], dtype=torch.float, device=device),
+                    torch.tensor([2], dtype=torch.long, device=device)]
 
         res = torch._foreach_add(tensors, 1)
         self.assertEqual(res, expected)
 
     def test_add_scalar_with_different_scalar_type(self, device):
         # int tensor with float scalar
-        # should go 'slow' route
         scalar = 1.1
         tensors = [torch.tensor([1], dtype=torch.int, device=device)]
-        res = torch._foreach_add(tensors, scalar)
-        self.assertEqual(res, [torch.tensor([2.1], device=device)])
+        self.assertEqual([x + scalar for x in tensors], torch._foreach_add(tensors, scalar))
 
         # float tensor with int scalar
-        # should go 'fast' route
         scalar = 1
         tensors = [torch.tensor([1.1], device=device)]
-        res = torch._foreach_add(tensors, scalar)
-        self.assertEqual(res, [torch.tensor([2.1], device=device)])
+        self.assertEqual([x + scalar for x in tensors], torch._foreach_add(tensors, scalar))
 
         # bool tensor with int scalar
-        # should go 'slow' route
         scalar = 1
         tensors = [torch.tensor([False], device=device)]
-        res = torch._foreach_add(tensors, scalar)
-        self.assertEqual(res, [torch.tensor([1], device=device)])
+        self.assertEqual([x + scalar for x in tensors], torch._foreach_add(tensors, scalar))
 
         # bool tensor with float scalar
-        # should go 'slow' route
         scalar = 1.1
         tensors = [torch.tensor([False], device=device)]
-        res = torch._foreach_add(tensors, scalar)
-        self.assertEqual(res, [torch.tensor([1.1], device=device)])
+        self.assertEqual([x + scalar for x in tensors], torch._foreach_add(tensors, scalar))
 
 instantiate_device_type_tests(TestForeach, globals())
 
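The dtype branches in these tests follow PyTorch's standard type-promotion rules. A minimal sketch (not part of the diff; shapes and values are illustrative) of the behavior the assertions encode:

```python
import torch

# bool tensor + integer scalar promotes to int64, matching the
# `if dtype == torch.bool: dtype = torch.int64` branch in the tests.
res = torch._foreach_add([torch.zeros(2, 2, dtype=torch.bool)], 1)
assert res[0].dtype == torch.int64

# int tensor + float scalar promotes to the default floating dtype
# (float32 unless the default has been changed).
res = torch._foreach_add([torch.tensor([1], dtype=torch.int)], 1.1)
assert res[0].dtype == torch.float32
```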
