nits on "add max_and_min function and cpu kernel to speed up observers" · pytorch/pytorch@4944b3f · GitHub

Commit 4944b3f

committed
nits on "add max_and_min function and cpu kernel to speed up observers"
Summary: For min/max based quantization observers, calculating the min and max of a tensor takes most of the runtime. Since min and max are computed over the same tensor, we can speed this up by reading the tensor only once and reducing with two outputs. One open question is whether this should live in the quantization namespace, since the use case is fairly specific.

This PR implements the easier CPU path to get an initial validation. Some additional work remains for future PRs, which @jpgraham will take a look at:
* CUDA kernel and tests
* making this work per channel
* benchmarking on observer
* benchmarking impact on QAT overhead

Test Plan:
```
python test/test_torch.py TestTorch.test_min_and_max
```

Quick bench (not representative of a real-world use case): https://gist.github.com/vkuzo/7fce61c3456dbc488d432430cafd6eca
```
(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=1 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.0390) tensor(-5.4485) tensor([-5.4485,  5.0390])
min and max separate 11.90243935585022
min and max combined 6.353186368942261
% decrease 0.466228209277153
(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=4 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.5586) tensor(-5.3983) tensor([-5.3983,  5.5586])
min and max separate 3.468616485595703
min and max combined 1.8227086067199707
% decrease 0.4745142294372342
(pytorch) [vasiliy@devgpu108.ash6 ~/local/pytorch] OMP_NUM_THREADS=8 python ~/nfs/pytorch_scripts/observer_bench.py
tensor(5.2146) tensor(-5.2858) tensor([-5.2858,  5.2146])
min and max separate 1.5707778930664062
min and max combined 0.8645427227020264
% decrease 0.4496085496757899
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D22589349](https://our.internmc.facebook.com/intern/diff/D22589349)

[ghstack-poisoned]
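The single-pass idea behind this change can be sketched outside of ATen as follows. This is a hypothetical standalone helper for illustration only, not the actual kernel (which uses `at::parallel_reduce` and vectorized chunk reducers); the function name `min_and_max` and the use of `std::vector<float>` are assumptions made for the sketch:

```cpp
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

// Compute min and max in a single pass over the data: the tensor is
// read once, and both outputs are updated per element. Returns the
// identity pair (+inf-like, -inf-like bounds) for empty input.
std::pair<float, float> min_and_max(const std::vector<float>& data) {
    std::pair<float, float> acc(
        std::numeric_limits<float>::max(),     // identity for min
        std::numeric_limits<float>::lowest()); // identity for max
    for (float v : data) {
        acc.first = std::min(acc.first, v);
        acc.second = std::max(acc.second, v);
    }
    return acc;
}
```

Compared with calling `min()` and then `max()` separately, this halves the number of memory reads, which is where the benchmarked ~45% runtime reduction comes from for this memory-bound reduction.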
2 parents 8e793df + 6e72f9f commit 4944b3f

File tree

1 file changed

+4
-4
lines changed


aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp

Lines changed: 4 additions & 4 deletions
```
@@ -114,14 +114,14 @@ inline void reduce_all_impl_two_outputs(
     Tensor& output1,
     Tensor& output2,
     const Tensor& input,
-    const std::pair<scalar_t, scalar_t> ident_v,
+    const std::pair<scalar_t, scalar_t>& ident_v,
     func_t1 reduce_chunk_func,
     func_t2 reduce_acc_func) {
   using scalar_t_pair = std::pair<scalar_t, scalar_t>;
   const int64_t input_numel = input.numel();
   auto input_data = input.data_ptr<scalar_t>();
   scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
-    [&](int64_t start, int64_t end, const scalar_t_pair ident) -> scalar_t_pair {
+    [&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
       scalar_t_pair partial_out(ident);
       for (int64_t i = start; i < end; i++) {
         partial_out = reduce_chunk_func(partial_out, input_data[i]);
@@ -139,7 +139,7 @@ inline void reduce_all_impl_vec_two_outputs(
     Tensor& output1,
     Tensor& output2,
     const Tensor& input,
-    const std::pair<scalar_t, scalar_t> ident_v,
+    const std::pair<scalar_t, scalar_t>& ident_v,
     func_t reduce_acc_func,
     vec_func_t1 reduce_chunk_func1,
     vec_func_t2 reduce_chunk_func2) {
@@ -149,7 +149,7 @@ inline void reduce_all_impl_vec_two_outputs(
   auto input_data = input.data_ptr<scalar_t>();
   // NOTE: parallel_reduce not support bool type
   std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
-    [&](int64_t start, int64_t end, const scalar_t_pair ident) -> scalar_t_pair {
+    [&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
       scalar_t_pair partial_out = vec256::reduce2_all<scalar_t>(
         [=](Vec x, Vec y) { return reduce_chunk_func1(x, y); },
         [=](Vec x, Vec y) { return reduce_chunk_func2(x, y); },
```
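The diff above touches a two-output parallel reduction: each chunk folds elements into a `(min, max)` pair via a chunk reducer, and the per-chunk partials are then combined with an accumulator reducer. A serial sketch of that pattern, with simplified stand-in names (`reduce_two_outputs`, `pair_f`, explicit chunk loop) in place of `at::parallel_reduce` and the templated functors, might look like this:

```cpp
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

using pair_f = std::pair<float, float>;

// Two-output reduction in the style of the kernel: split the input
// into chunks, reduce each chunk into a (min, max) partial starting
// from the identity pair, then combine the partials. The real kernel
// runs the chunk loop in parallel and vectorizes the inner fold.
pair_f reduce_two_outputs(const std::vector<float>& data, int num_chunks) {
    const pair_f ident{std::numeric_limits<float>::max(),
                       std::numeric_limits<float>::lowest()};
    const int64_t n = static_cast<int64_t>(data.size());
    std::vector<pair_f> partials(num_chunks, ident);
    for (int c = 0; c < num_chunks; ++c) {
        const int64_t start = c * n / num_chunks;
        const int64_t end = (c + 1) * n / num_chunks;
        pair_f acc = ident;
        for (int64_t i = start; i < end; ++i) {
            // "reduce_chunk_func": fold one element into both outputs
            acc.first = std::min(acc.first, data[i]);
            acc.second = std::max(acc.second, data[i]);
        }
        partials[c] = acc;
    }
    pair_f result = ident;
    for (const pair_f& p : partials) {
        // "reduce_acc_func": combine per-chunk partial results
        result.first = std::min(result.first, p.first);
        result.second = std::max(result.second, p.second);
    }
    return result;
}
```

The nit itself is simply to pass `ident_v` and `ident` by `const&` rather than by value, avoiding a `std::pair` copy at each call boundary; the reduction logic is unchanged.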
