add max_and_min function and cpu kernel to speed up observers (#41570) · pytorch/pytorch@302e566

Commit 302e566

vkuzo authored and facebook-github-bot committed
add max_and_min function and cpu kernel to speed up observers (#41570)
Summary:

Pull Request resolved: #41570

For min/max based quantization observers, calculating the min and max of a tensor takes most of the runtime. Since both reductions run over the same tensor, we can speed this up by reading the tensor only once and reducing with two outputs. One open question is whether this should go into the quantization namespace, since the use case is pretty specific.

This PR implements the easier CPU path to get an initial validation. Some additional work is needed in future PRs, which durumu will take a look at:
* CUDA kernel and tests
* making this work per channel
* benchmarking on observer
* benchmarking impact on QAT overhead

Test Plan:
```
python test/test_torch.py TestTorch.test_min_and_max
```

Quick bench (not representative of a real world use case): https://gist.github.com/vkuzo/7fce61c3456dbc488d432430cafd6eca
```
(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=1 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.0390)
tensor(-5.4485)
tensor([-5.4485, 5.0390])
min and max separate 11.90243935585022
min and max combined 6.353186368942261
% decrease 0.466228209277153

(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=4 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.5586)
tensor(-5.3983)
tensor([-5.3983, 5.5586])
min and max separate 3.468616485595703
min and max combined 1.8227086067199707
% decrease 0.4745142294372342

(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=8 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.2146)
tensor(-5.2858)
tensor([-5.2858, 5.2146])
min and max separate 1.5707778930664062
min and max combined 0.8645427227020264
% decrease 0.4496085496757899
```

Imported from OSS

Reviewed By: supriyar

Differential Revision: D22589349

fbshipit-source-id: c2e3f1b8b5c75a23372eb6e4c885f842904528ed
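As a rough sketch of what the fused op buys an observer (illustrative only: the wrapper function below is hypothetical, and it assumes the generated ATen binding `at::_min_max` produced from the `native_functions.yaml` entry in this commit), the point is that one pass over the tensor replaces two:

```cpp
#include <ATen/ATen.h>
#include <tuple>

// Hypothetical observer-style update: two separate reductions vs. the fused
// reduction added by this PR. Both variants produce 0-dim tensors for min/max.
void observer_min_max(const at::Tensor& x) {
  // Before: two full passes over x.
  at::Tensor mn_separate = x.min();
  at::Tensor mx_separate = x.max();

  // After: one pass over x, both results produced together.
  at::Tensor mn_fused, mx_fused;
  std::tie(mn_fused, mx_fused) = at::_min_max(x);
}
```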
1 parent 9e0c746 commit 302e566

File tree

6 files changed: +190 −16 lines

aten/src/ATen/cpu/vec256/functional.h

Lines changed: 29 additions & 0 deletions
@@ -44,6 +44,35 @@ inline scalar_t reduce_all(const Op& vec_fun, scalar_t* data, int64_t size) {
   return vec_reduce_all(vec_fun, acc_vec, Vec::size());
 }
 
+// similar to reduce_all, but reduces into two outputs
+template <typename scalar_t, typename Op1, typename Op2>
+inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+    scalar_t* data, int64_t size) {
+  using Vec = vec256::Vec256<scalar_t>;
+  if (size < Vec::size()) {
+    auto loaded_data = Vec::loadu(data, size);
+    return std::pair<scalar_t, scalar_t>(
+      vec_reduce_all(vec_fun1, loaded_data, size),
+      vec_reduce_all(vec_fun2, loaded_data, size));
+  }
+  int64_t d = Vec::size();
+  Vec acc_vec1 = Vec::loadu(data);
+  Vec acc_vec2 = Vec::loadu(data);
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    acc_vec1 = vec_fun1(acc_vec1, data_vec);
+    acc_vec2 = vec_fun2(acc_vec2, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
+    acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
+  }
+  return std::pair<scalar_t, scalar_t>(
+    vec_reduce_all(vec_fun1, acc_vec1, Vec::size()),
+    vec_reduce_all(vec_fun2, acc_vec2, Vec::size()));
+}
+
 template <typename scalar_t, typename MapOp, typename ReduceOp>
 inline scalar_t map_reduce_all(
     const MapOp& map_fun,
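As a rough illustration of how the new helper is meant to be called (mirroring what the CPU kernel later in this commit does per chunk), a single pass over a float buffer can produce both reductions at once. This is a sketch only: the wrapper function name is hypothetical, and it assumes the vec256 headers and the `minimum`/`maximum` vector ops as they existed at the time of this commit.

```cpp
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/cpu/vec256/functional.h>
#include <cstdint>
#include <utility>

// Hypothetical wrapper: reduce a contiguous float buffer to (min, max) in one pass.
std::pair<float, float> min_max_single_pass(float* data, int64_t size) {
  using Vec = at::vec256::Vec256<float>;
  return at::vec256::reduce2_all<float>(
      [](Vec x, Vec y) { return at::vec256::minimum(x, y); },  // first accumulator: running min
      [](Vec x, Vec y) { return at::vec256::maximum(x, y); },  // second accumulator: running max
      data,
      size);
}
```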

aten/src/ATen/native/ReduceAllOps.cpp

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,7 @@ namespace native {
 
 DEFINE_DISPATCH(min_all_stub);
 DEFINE_DISPATCH(max_all_stub);
+DEFINE_DISPATCH(_min_max_all_stub);
 
 Tensor min(const Tensor &self) {
   TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
@@ -25,4 +26,13 @@ Tensor max(const Tensor &self) {
   return result;
 }
 
+std::tuple<Tensor, Tensor> _min_max(const Tensor &self) {
+  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
+  Tensor min_result = at::empty({}, self.options());
+  Tensor max_result = at::empty({}, self.options());
+  _min_max_all_stub(self.device().type(), min_result, max_result, self.contiguous());
+  return std::tuple<Tensor&, Tensor&>(min_result, max_result);
+}
+
 }} // namesapce at::native

aten/src/ATen/native/ReduceAllOps.h

Lines changed: 2 additions & 0 deletions
@@ -6,7 +6,9 @@
 namespace at { namespace native {
 
 using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
+using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
 DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
 DECLARE_DISPATCH(reduce_all_fn, max_all_stub);
+DECLARE_DISPATCH(reduce_min_max_fn, _min_max_all_stub);
 
 }}
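For readability, here is how the new stub is wired together across the three C++ files touched by this commit: declared in the header, defined at the op layer, and registered by the CPU kernel. This only collects lines that appear in the hunks above and below; it is not additional code.

```cpp
// aten/src/ATen/native/ReduceAllOps.h
using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
DECLARE_DISPATCH(reduce_min_max_fn, _min_max_all_stub);

// aten/src/ATen/native/ReduceAllOps.cpp
DEFINE_DISPATCH(_min_max_all_stub);
// ...and at::native::_min_max() invokes the stub with the device type:
// _min_max_all_stub(self.device().type(), min_result, max_result, self.contiguous());

// aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
REGISTER_DISPATCH(_min_max_all_stub, &_min_max_all_kernel_impl);
```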

aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp

Lines changed: 107 additions & 2 deletions
@@ -27,7 +27,7 @@ inline void reduce_all_impl_vec(
   const int64_t input_numel = input.numel();
   auto input_data = input.data_ptr<scalar_t>();
   // NOTE: parallel_reduce not support bool type
-  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
+  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
       scalar_t partial_out = vec256::reduce_all<scalar_t>(
         [=](Vec x, Vec y) { return vop(x, y); },
@@ -47,7 +47,7 @@ inline void reduce_all_impl(
     func_t op) {
   const int64_t input_numel = input.numel();
   auto input_data = input.data_ptr<scalar_t>();
-  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
+  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
       scalar_t partial_out = ident;
       for (int64_t i = start; i < end; i++) {
@@ -108,9 +108,114 @@ static void max_all_kernel_impl(Tensor& result, const Tensor& input) {
   }
 }
 
+// For operation not support in avx/avx2
+template <typename scalar_t, typename func_t1, typename func_t2>
+inline void reduce_all_impl_two_outputs(
+    Tensor& output1,
+    Tensor& output2,
+    const Tensor& input,
+    const std::pair<scalar_t, scalar_t>& ident_v,
+    func_t1 reduce_chunk_func,
+    func_t2 reduce_acc_func) {
+  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
+  const int64_t input_numel = input.numel();
+  auto input_data = input.data_ptr<scalar_t>();
+  scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
+    [&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
+      scalar_t_pair partial_out(ident);
+      for (int64_t i = start; i < end; i++) {
+        partial_out = reduce_chunk_func(partial_out, input_data[i]);
+      }
+      return partial_out;
+    },
+    reduce_acc_func
+  );
+  output1.fill_(result.first);
+  output2.fill_(result.second);
+}
+
+template <typename scalar_t, typename func_t, typename vec_func_t1, typename vec_func_t2>
+inline void reduce_all_impl_vec_two_outputs(
+    Tensor& output1,
+    Tensor& output2,
+    const Tensor& input,
+    const std::pair<scalar_t, scalar_t>& ident_v,
+    func_t reduce_acc_func,
+    vec_func_t1 reduce_chunk_func1,
+    vec_func_t2 reduce_chunk_func2) {
+  using Vec = Vec256<scalar_t>;
+  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
+  const int64_t input_numel = input.numel();
+  auto input_data = input.data_ptr<scalar_t>();
+  // NOTE: parallel_reduce not support bool type
+  std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
+    [&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
+      scalar_t_pair partial_out = vec256::reduce2_all<scalar_t>(
+        [=](Vec x, Vec y) { return reduce_chunk_func1(x, y); },
+        [=](Vec x, Vec y) { return reduce_chunk_func2(x, y); },
+        input_data + start,
+        end - start);
+      return partial_out;
+    },
+    reduce_acc_func
+  );
+  output1.fill_(result.first);
+  output2.fill_(result.second);
+}
+
+static void _min_max_all_kernel_impl(Tensor& min_result, Tensor& max_result,
+    const Tensor& input) {
+  if (input.scalar_type() == ScalarType::Bool) {
+    TensorIterator iter = TensorIteratorConfig()
+      .add_input(input)
+      .build();
+    bool min_result_data = true;
+    bool max_result_data = false;
+    cpu_serial_kernel(iter, [&](const bool a) -> void {
+      min_result_data = min_result_data && a;
+      max_result_data = max_result_data || a;
+    });
+    min_result.fill_(min_result_data);
+    max_result.fill_(max_result_data);
+  } else if (input.scalar_type() == ScalarType::Long) {
+    // for int64_t, vectorized implementation have performance issue,
+    // just use scalar path
+    using int64_t_pair = std::pair<int64_t, int64_t>;
+    reduce_all_impl_two_outputs<int64_t>(min_result, max_result, input,
+      int64_t_pair(upper_bound<int64_t>(), lower_bound<int64_t>()),
+      // reduce over chunk
+      [=](int64_t_pair a, int64_t b) -> int64_t_pair {
+        return int64_t_pair(min_impl(a.first, b), max_impl(a.second, b));
+      },
+      // combine two inputs
+      [=](int64_t_pair a, int64_t_pair b) -> int64_t_pair {
+        return int64_t_pair(min_impl(a.first, b.first), max_impl(a.second, b.second));
+      }
+    );
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(input.scalar_type(), "_min_max_all", [&] {
+      using Vec = vec256::Vec256<scalar_t>;
+      using scalar_t_pair = std::pair<scalar_t, scalar_t>;
+      reduce_all_impl_vec_two_outputs<scalar_t>(
+        min_result,
+        max_result,
+        input,
+        scalar_t_pair(upper_bound<scalar_t>(), lower_bound<scalar_t>()),
+        [=](scalar_t_pair a, scalar_t_pair b) -> scalar_t_pair {
+          return scalar_t_pair(
+            min_impl(a.first, b.first), max_impl(a.second, b.second));
+        },
+        [=](Vec a, Vec b) -> Vec { return minimum(a, b); },
+        [=](Vec a, Vec b) -> Vec { return maximum(a, b); }
+      );
+    });
+  }
+}
+
 } // namespace
 
 REGISTER_DISPATCH(min_all_stub, &min_all_kernel_impl);
 REGISTER_DISPATCH(max_all_stub, &max_all_kernel_impl);
+REGISTER_DISPATCH(_min_max_all_stub, &_min_max_all_kernel_impl);
 
 }}
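As a companion to the two helpers above, here is a minimal, self-contained sketch of the `at::parallel_reduce` pattern they both follow, written as a plain scalar loop with a hypothetical wrapper name: each worker reduces its `[start, end)` chunk to a partial (min, max) pair, and partial pairs from different chunks are merged by the combine function. The real kernel uses the vectorized `reduce2_all` path for floating-point types and the scalar path only for int64_t and bool.

```cpp
#include <ATen/Parallel.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>

// Hypothetical scalar-only analogue of reduce_all_impl_two_outputs:
// one fused pass produces (min, max) instead of two separate passes.
std::pair<float, float> min_max_scalar(const float* data, int64_t n) {
  using pair_t = std::pair<float, float>;
  const pair_t ident(std::numeric_limits<float>::max(),      // identity for min
                     std::numeric_limits<float>::lowest());  // identity for max
  return at::parallel_reduce(
      0, n, at::internal::GRAIN_SIZE, ident,
      // Reduce one [start, end) chunk to a partial (min, max) pair.
      [&](int64_t start, int64_t end, const pair_t& acc0) -> pair_t {
        pair_t acc = acc0;
        for (int64_t i = start; i < end; ++i) {
          acc.first = std::min(acc.first, data[i]);
          acc.second = std::max(acc.second, data[i]);
        }
        return acc;
      },
      // Combine partial results from two chunks.
      [](const pair_t& a, const pair_t& b) -> pair_t {
        return pair_t(std::min(a.first, b.first), std::max(a.second, b.second));
      });
}
```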

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
@@ -5076,6 +5076,13 @@
     CPU, CUDA: max
     QuantizedCPU: max_quant
 
+# Return: (Tensor min, Tensor max)
+- func: _min_max(Tensor self) -> (Tensor, Tensor)
+  use_c10_dispatcher: full
+  variants: function
+  dispatch:
+    CPU: _min_max
+
 - func: median(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: method, function

test/test_torch.py

Lines changed: 35 additions & 14 deletions
@@ -275,7 +275,7 @@ def test_has_storage(self):
         self.assertIsNotNone(torch.Tensor([0, 0, 0]).nonzero().storage())
         self.assertIsNotNone(torch.Tensor().new().storage())
 
-    def _testSelection(self, torchfn, mathfn):
+    def _testSelection(self, torchfn, mathfn, skip_indices=False):
         # contiguous
         m1 = torch.randn(100, 100)
         res1 = torchfn(m1)
@@ -294,20 +294,21 @@ def _testSelection(self, torchfn, mathfn):
         self.assertEqual(res1, res2)
 
         # with indices
-        m1 = torch.randn(100, 100)
-        res1val, res1ind = torchfn(m1, 1, False)
-        res2val = m1[:, 0:1].clone().squeeze()
-        res2ind = res1ind.clone().fill_(0)
-        for i, j in iter_indices(m1):
-            if mathfn(res2val[i], m1[i, j]) != res2val[i]:
-                res2val[i] = m1[i, j]
-                res2ind[i] = j
+        if not skip_indices:
+            m1 = torch.randn(100, 100)
+            res1val, res1ind = torchfn(m1, 1, False)
+            res2val = m1[:, 0:1].clone().squeeze()
+            res2ind = res1ind.clone().fill_(0)
+            for i, j in iter_indices(m1):
+                if mathfn(res2val[i], m1[i, j]) != res2val[i]:
+                    res2val[i] = m1[i, j]
+                    res2ind[i] = j
 
-        maxerr = 0
-        for i in range(res1val.size(0)):
-            maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
-            self.assertEqual(res1ind[i], res2ind[i])
-        self.assertLessEqual(abs(maxerr), 1e-5)
+            maxerr = 0
+            for i in range(res1val.size(0)):
+                maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
+                self.assertEqual(res1ind[i], res2ind[i])
+            self.assertLessEqual(abs(maxerr), 1e-5)
 
         # NaNs
         for index in (0, 4, 99):
@@ -327,12 +328,32 @@ def _testSelection(self, torchfn, mathfn):
             res2 = mathfn(res2, m1[i])
         self.assertEqual(res1, res2)
 
+        # Long
+        m1 = torch.LongTensor(100).random_(-1000, 1000)
+        res1 = torchfn(m1)
+        res2 = m1[0]
+        for i in iter_indices(m1):
+            res2 = mathfn(res2, m1[i])
+        self.assertEqual(res1, res2)
+
+
     def test_max(self):
         self._testSelection(torch.max, max)
 
     def test_min(self):
         self._testSelection(torch.min, min)
 
+    def test_min_max(self):
+        # TODO: implement indices, in a future PR
+        # min correctness
+        self._testSelection(lambda x: torch._min_max(x)[0],
+                            lambda x, y: min(x, y),
+                            skip_indices=True)
+        # max correctness
+        self._testSelection(lambda x: torch._min_max(x)[1],
+                            lambda x, y: max(x, y),
+                            skip_indices=True)
+
     def test_dim_reduction_uint8_overflow(self):
         example = [[-1, 2, 1], [5, 3, 6]]
         x = torch.tensor(example, dtype=torch.uint8)
0 commit comments