🐛 Bug
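Using torch.median(x, dim=0) on a tensor that resides on the CPU is considerably slower than explicitly sorting the tensor with torch.sort() and selecting the middle element. In the benchmark below, torch.median takes 19.04 s, about 22 times longer than the 856.96 ms needed for the sort.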
To Reproduce
Steps to reproduce the behavior:
>>> from torch.utils.benchmark import Timer
>>> t = Timer(stmt="torch.median(x, dim=0)", setup="x=torch.rand(11, 1000, 4000, device='cpu')")
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8e5527ba60>
torch.median(x, dim=0)
19.04 s
1 measurement, 1 runs , 1 thread
>>> t = Timer(stmt="torch.sort(x, dim=0).values[5]", setup="x=torch.rand(11, 1000, 4000, device='cpu')")
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8e584a7880>
torch.sort(x, dim=0).values[5]
856.96 ms
1 measurement, 1 runs , 1 thread
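The REPL session above omits the lines that actually ran the timers. The following is a minimal reproduction sketch. It assumes the measurements came from Timer.timeit(1), which gives one run per statement and is consistent with "1 runs" in the output. The helper names `time_once` and `compare` are hypothetical, and `torch` is passed to `Timer` explicitly through `globals`.

```python
# Sketch: time torch.median against a sort-based median with torch.utils.benchmark.
# Assumption: the elided REPL lines called Timer.timeit(1), i.e. a single run per statement.
import torch
from torch.utils.benchmark import Timer

# Setup string copied from the report (CPU case).
CPU_SETUP = "x = torch.rand(11, 1000, 4000, device='cpu')"


def time_once(stmt: str, setup: str = CPU_SETUP) -> float:
    """Run `stmt` once after `setup` and return the elapsed time in seconds."""
    # Pass torch explicitly so both `setup` and `stmt` can refer to it.
    timer = Timer(stmt=stmt, setup=setup, globals={"torch": torch})
    # With a single run, the median of the measurement is just that run's time.
    return timer.timeit(1).median


def compare(setup: str = CPU_SETUP) -> tuple:
    """Return (median_seconds, sort_seconds) for the two statements in the report."""
    t_median = time_once("torch.median(x, dim=0)", setup)
    t_sort = time_once("torch.sort(x, dim=0).values[5]", setup)
    return t_median, t_sort
```

Calling `compare()` with the default setup repeats the CPU measurement. Passing `"x = torch.rand(11, 1000, 4000, device='cuda')"` repeats the CUDA case. Absolute timings will depend on the PyTorch version and build.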
For CUDA tensors I did not observe such a large difference: torch.median takes 56.26 ms and the sort 39.60 ms.
>>> from torch.utils.benchmark import Timer
>>> t = Timer(stmt="torch.median(x, dim=0)", setup="x=torch.rand(11, 1000, 4000, device='cuda')")
<torch.utils.benchmark.utils.common.Measurement object at 0x7f85036f86d0>
torch.median(x, dim=0)
56.26 ms
1 measurement, 10 runs , 1 thread
>>> t = Timer(stmt="torch.sort(x, dim=0).values[5]", setup="x=torch.rand(11, 1000, 4000, device='cuda')")
<torch.utils.benchmark.utils.common.Measurement object at 0x7f85036192b0>
torch.sort(x, dim=0).values[5]
39.60 ms
1 measurement, 10 runs , 1 thread
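Both benchmarks use torch.sort(x, dim=0).values[5] as a stand-in for the median. That index is only correct because the size along dim 0 is odd (11). For an even size, torch.median returns the lower of the two middle values, which sits at sorted position (n - 1) // 2.

The sketch below generalises the index and checks exact agreement with torch.median on the CPU and, if available, on CUDA. The helper names `sort_median` and `matches_median` are hypothetical.

```python
# Sketch: check that the sort-based workaround matches torch.median on each device.
import torch


def sort_median(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Lower median of `x` along `dim`, computed with a full sort (hypothetical helper)."""
    n = x.size(dim)
    # torch.median returns the lower of the two middle values for even n,
    # which is sorted position (n - 1) // 2; for n == 11 this is index 5.
    return torch.sort(x, dim=dim).values.select(dim, (n - 1) // 2)


def matches_median(device: str, n: int = 11) -> bool:
    """True if sort_median agrees exactly with torch.median(x, dim=0).values."""
    x = torch.rand(n, 100, 400, device=device)
    return torch.equal(sort_median(x, dim=0), torch.median(x, dim=0).values)


devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    for n in (10, 11):  # even and odd sizes along dim 0
        print(device, n, matches_median(device, n))
```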
Expected behavior
Since computing the median does not require sorting the whole tensor, torch.median should be at least as fast as torch.sort on the same input.
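To illustrate why a median needs no full sort: a selection algorithm such as quickselect finds the k-th smallest element in expected linear time, versus O(n log n) for sorting. The pure-Python sketch below is illustrative only and is not how PyTorch implements torch.median. The function names are hypothetical. It uses the same lower-median convention as torch.median for even-length input.

```python
# Sketch: selecting the median without a full sort (quickselect).
# Illustrative only; this is not PyTorch's implementation of torch.median.
import random
from typing import List, Sequence


def quickselect(values: Sequence[float], k: int) -> float:
    """Return the k-th smallest element (0-based) in expected O(n) time.

    Assumes the values contain no NaNs, since NaN comparisons are always False.
    """
    if not 0 <= k < len(values):
        raise IndexError("k is out of range")
    items: List[float] = list(values)
    while True:
        pivot = random.choice(items)
        less = [v for v in items if v < pivot]
        n_equal = sum(1 for v in items if v == pivot)
        if k < len(less):
            items = less  # the answer lies strictly below the pivot
        elif k < len(less) + n_equal:
            return pivot  # the answer equals the pivot
        else:
            k -= len(less) + n_equal  # skip the smaller and equal elements
            items = [v for v in items if v > pivot]


def lower_median(values: Sequence[float]) -> float:
    """Lower median, matching torch.median's convention for even-length input."""
    return quickselect(values, (len(values) - 1) // 2)
```

Each pass does work proportional to the remaining items and, with a random pivot, discards a constant fraction of them on average. That is where the expected linear cost comes from.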
Environment
Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
Nvidia driver version: 450.102.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
Additional context
cc @VitalyFedyunin @ngimel @heitorschueroff @mruberry @rgommers