🐛 Bug
Reduce operations that depend on binary_kernel_reduce suffer from huge slowdowns (up to 100x in most cases) whenever the dimension over which they are applied is small compared to the other dimensions. These include torch.norm / torch.linalg.norm, torch.std, torch.var, torch.argmax, and torch.argmin.
Apart from the above, torch.mean, torch.sum, torch.max, and torch.min suffer from less severe performance losses. The code snippet below illustrates the performance gaps:
In [34]: inp = torch.rand(10**6, 2) # Similar observations for higher dimensions.
In [35]: %timeit torch.sum(inp)
210 µs ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [36]: %timeit torch.sum(inp, dim=0)
998 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [37]: %timeit torch.sum(inp, dim=1)
4.78 ms ± 61.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [43]: %timeit torch.max(inp)
306 µs ± 932 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [44]: %timeit torch.max(inp, dim=0)
1.58 ms ± 40.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [45]: %timeit torch.max(inp, dim=1)
4.08 ms ± 6.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [47]: %timeit torch.norm(inp)
1.46 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [47]: %timeit torch.norm(inp, dim=0)
1.59 ms ± 9.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [48]: %timeit torch.norm(inp, dim=1)
146 ms ± 415 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [50]: %timeit torch.mean(inp)
248 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [51]: %timeit torch.mean(inp, dim=0)
1.07 ms ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [52]: %timeit torch.mean(inp, dim=1)
4.87 ms ± 25.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
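For reproducing this outside IPython, something along the following lines should give comparable numbers. This is a minimal sketch that assumes torch.utils.benchmark is available in this build (it ships with recent PyTorch releases); the op and argument strings are just illustrative.

import torch
from torch.utils import benchmark

inp = torch.rand(10**6, 2)

# Time the full reduction and the per-dimension reductions for a few of the
# affected ops. The dim=1 case reduces over only 2 elements per output row,
# which is where the slowdown is most pronounced.
for op in ("torch.sum", "torch.max", "torch.norm", "torch.mean"):
    for args in ("inp", "inp, dim=0", "inp, dim=1"):
        t = benchmark.Timer(stmt=f"{op}({args})", globals={"inp": inp, "torch": torch})
        print(f"{op}({args}): {t.blocked_autorange().median * 1e3:.3f} ms")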
I have been studying the implementations and looking for the bottlenecks. It turns out that TensorIteratorBase::foreach_reduced_elt, which binary_kernel_reduce relies upon, is the bottleneck:
pytorch/aten/src/ATen/native/TensorIteratorReduce.cpp, lines 121 to 172 in 93973ee
void TensorIteratorBase::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) {
  AT_ASSERT(ninputs() == 1);
  AT_ASSERT(noutputs() >= 1);

  auto shape = this->shape();
  if (output(0).numel() == 0) {
    return;
  }
  if (output(0).numel() == 1) {
    loop(*this);
  }
  else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
      at::in_parallel_region() || !parallelize) {
    auto reduce_dims = num_reduce_dims();

    auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);

    int64_t non_reduced_numel = 1;
    for (int i = 0; i < non_reduced_shape.size(); ++i) {
      non_reduced_numel *= non_reduced_shape[i];
    }
    DimCounter dims {non_reduced_shape, {0, non_reduced_numel}};
    while (!dims.is_done()) {
      TensorIterator reduced = *this;
      reduced.select_all_keeping_dim(reduce_dims, dims.values);
      loop(reduced);
      dims.increment({1, 1});
    }
  }
  else {
    int dim = find_split_dim(*this);
    int64_t cols = shape[dim];
    at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
      if (begin == end) {
        return;
      }
      TensorIterator sub_iter(*this);
      sub_iter.narrow(dim, begin, end - begin);
      // On some broken setups, `#ifdef _OPENMP` is true,
      // and `get_num_threads` returns > 1, but
      // `#pragma omp parallel` is ignored.
      // There is no API to check for this, so we need to explicitly
      // stop trying to parallelize if we've already gotten here.
      //
      // (If we are on one of those broken setups, we will
      // only have one thread here, and end - begin == cols.)
      sub_iter.foreach_reduced_elt(loop, false);
    });
  }
}
The slowdown is mainly due to the following while loop:
pytorch/aten/src/ATen/native/TensorIteratorReduce.cpp, lines 143 to 148 in 93973ee
while (!dims.is_done()) {
  TensorIterator reduced = *this;
  reduced.select_all_keeping_dim(reduce_dims, dims.values);
  loop(reduced);
  dims.increment({1, 1});
}
Consider, for example:
inp = torch.rand(10**6, 2)
torch.norm(inp, dim=1)
If N threads are spawned, then each thread iterates (10**6) / N times in the above loop, which is not great given that the operations performed inside the loop are not lightweight. I have done some profiling here. A faster, more direct approach that avoids DimCounter within TensorIteratorBase::foreach_reduced_elt should be possible; that method will likely need redesigning. Similar observations were found for torch.norm. I will be happy to work on this and can send a PR.
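As a rough sanity check that the time goes into the per-row sub-iterator setup rather than into the floating-point work, the same 2-norm can be computed through elementwise kernels plus a plain sum reduction. This is only an illustration under the shapes used above, not a proposed fix:

import torch

inp = torch.rand(10**6, 2)

# Dispatches to binary_kernel_reduce / foreach_reduced_elt: the DimCounter
# loop above runs once per output row, i.e. on the order of 10**6 times
# split across the worker threads.
slow = torch.norm(inp, dim=1)

# Mathematically equivalent 2-norm built from elementwise kernels and a
# torch.sum reduction. Going by the timings above (~5 ms for sum(dim=1)
# versus ~146 ms for norm(dim=1)), this path should be far cheaper, which
# points at the iteration machinery rather than the arithmetic.
fast = inp.pow(2).sum(dim=1).sqrt()

print(torch.allclose(slow, fast))  # expected: True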
Environment
PyTorch version: 1.8.0a0+508bab4
Is debug build: False
CUDA used to build PyTorch: Only CPU
ROCM used to build PyTorch: N/A
OS: Linux Mint 20 (x86_64)
GCC version: (Ubuntu 10.2.0-5ubuntu1~20.04) 10.2.0
Clang version: Could not collect
CMake version: version 3.16.3
Python version: 3.7 (64-bit runtime)
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.0.dev0+430.ge745a19cb
[pip3] torch==1.8.0a0+unknown
[conda] Could not collect
cc @VitalyFedyunin @ngimel