🐛 Describe the bug
(I'll add the actual benchmarking details, logs, and `output_code.py` in a bit.)
I'm benchmarking min_sum and mul_sum in two setups:
1. `(D,) x (D,) -> scalar`
2. `(B, N, 1, D) x (B, 1, N, D) -> (B, N, N)`

Case (1) is similar to computing some sort of norm of a given input or a distance between two inputs.
Case (2) is matmul-like and is similar to computing distances between all pairs in the batch (mm / pdist / cdist).
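For concreteness, a minimal sketch of the two setups (shapes taken from the benchmark script at the end):

```python
import torch

D, B, N = 192, 6, 183
min_sum = lambda x, y: torch.min(x, y).sum(-1)

# case (1): (D,) x (D,) -> scalar
a, b = torch.rand(D), torch.rand(D)
print(min_sum(a, b).shape)   # torch.Size([])

# case (2): (B, N, 1, D) x (B, 1, N, D) -> (B, N, N) via broadcasting
A = torch.rand(B, N, D).unsqueeze(2)   # (B, N, 1, D)
B2 = torch.rand(B, N, D).unsqueeze(1)  # (B, 1, N, D)
print(min_sum(A, B2).shape)  # torch.Size([6, 183, 183])
```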
When running `python3 bench_minmul_sum.py --enable-dynamic --avx512 --verbose1` (full script at the end):
Findings and questions regarding dynamic shapes:
- Dynamic / static shape options are not confined to the `torch.compile` call (which is super-unintuitive and brittle). When using `--enable-dynamic`, all produced `output_code.py` contain dynamic shapes (meaning dynamic shapes ended up being used not only for `min_sum_compiled_dynamic` but also for `min_sum_compiled`).
- When not using `--enable-dynamic`, all produced `output_code.py` contain static shapes. Ideally this could lead to full loop unrolling for small dims, but currently loop unrolling can only be done by the underlying C++ compiler; loop-unrolling pragmas are missing from the generated C++.
- Dynamic-shape `output_code.py` does not record divisibility and always contains an extra tail loop.
- Ideally, I would like static shapes for 1d ops (the inner dim is static, or there are only a few possible values in my use case), static shapes for the inner dim of the 2d ops, and dynamic shapes for the batch dims (as can be done for `torch.onnx.export`); see the sketch after this list. Currently, parametrizing static vs. dynamic shapes is unpredictable :( This is not very good :(
- Why is `#pragma omp simd simdlen(8)` used? Apparently it's correct, but if it is the register size in float32 lanes, I would expect `simdlen(16)`, no? Also, I had a hard time finding the resulting assembly and checking whether the compiler did loop unrolling (there was no `#pragma omp unroll`). How can one do that? Could the debug option also print the produced assembly (obtained from `objdump`)? Currently there are also too many logs; it is hard to get through all of them.
- Better explanation of the tiling story to users is important, especially for matmul-like ops. These custom matmul-like ops are quite common in kernel methods and similar (see pykeops): max-sum operation #59545, [Feature Request] Compile compatible Neighborhood Algorithms for large Tensors #97006, Direct Implementation of K-Nearest neighbor (KNN) in pytorch #71386
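On the per-dimension point, the closest workaround I can think of is marking individual dims as dynamic. A hedged sketch, assuming `torch._dynamo.mark_dynamic` behaves as documented in this nightly (I have not verified it yields the specialization I want):

```python
import torch
import torch._dynamo

mul_sum = lambda x, y: torch.mul(x, y).sum(-1)
compiled = torch.compile(mul_sum)

A = torch.rand(6, 183, 192).unsqueeze(2)  # (B, N, 1, D)
B = torch.rand(6, 183, 192).unsqueeze(1)  # (B, 1, N, D)

# mark only the batch-like dims as dynamic; leave the inner reduction
# dim D = 192 unmarked so it can stay static / specialized
torch._dynamo.mark_dynamic(A, 0)  # B
torch._dynamo.mark_dynamic(A, 1)  # N
torch._dynamo.mark_dynamic(B, 0)  # B
torch._dynamo.mark_dynamic(B, 2)  # N

out = compiled(A, B)  # (B, N, N)
```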
Findings and questions regarding NaNs:
- The code produced for min_sum has NaN handling (see the eager check below).
- It would be good to be able to give an assertion or a hint to the compiler that the inputs will not contain NaNs, so that being NaN-compliant is not needed. This NaN handling might be a perf hit for small 1d vectors.
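For reference, the eager elementwise min does propagate NaNs (per `torch.minimum` semantics), which is presumably why the generated code carries the NaN handling:

```python
import torch

a = torch.tensor([float('nan'), 1.0])
b = torch.tensor([0.0, 2.0])

# if either compared element is NaN, NaN is returned, so the compiled
# kernel has to preserve this unless told the inputs are NaN-free
print(torch.min(a, b))  # tensor([nan, 1.])
```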
Findings regarding mul_sum (which is equivalent to a matrix product or dot product):
- The mul_sum pattern is not recognized, and no `gemm` or `dot` call is produced. It seems that no tiling is done despite the fact that sum-reduction is used. (An explicit matmul rewrite is sketched below for comparison.)
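For comparison, a sketch of the rewrite I would hope this pattern gets mapped to: the same reduction expressed as an explicit batched matmul, which does go through gemm in eager mode (numerics may differ slightly due to accumulation order):

```python
import torch

x = torch.rand(6, 183, 192)
y = torch.rand(6, 183, 192)

# broadcast-multiply-then-reduce form used in the benchmark
out_mul_sum = torch.mul(x.unsqueeze(2), y.unsqueeze(1)).sum(-1)  # (B, N, N)

# equivalent explicit batched matmul, dispatched to gemm
out_matmul = x @ y.transpose(-1, -2)                             # (B, N, N)

# expect True up to accumulation-order differences
print(torch.allclose(out_mul_sum, out_matmul, rtol=1e-4, atol=1e-4))
```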
Findings regarding the benchmarking:
- On some Linux platforms (e.g. Windows WSLv1) it can be hard to account for CPU throttling. It would be good if PyTorch had some recommendations on how to account for it properly; on laptops on battery there is often quite severe CPU throttling. (A median-of-repeats timing sketch follows this list.)
- `config.cpp.simdlen` is 512 even though, without `ATEN_CPU_CAPABILITY=avx512`, `torch.__config__.show()` reports `CPU capability usage: AVX2` (which means 256-bit registers only). And it seems that this 256-bit ATEN_CPU_CAPABILITY is discovered automatically, even though the laptop supports avx512.
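As an example of the kind of recommendation I mean for the throttling issue, a minimal sketch (repeat/iteration counts are arbitrary): take the median over several repeats with `time.perf_counter`, which is less sensitive to throttling spikes than a single wall-clock measurement:

```python
import statistics
import time

def bench_median_ms(f, *args, repeats=7, iters=100):
    # time `iters` calls, repeat `repeats` times, report the median per-call time in ms
    times = []
    for _ in range(repeats):
        tic = time.perf_counter()
        for _ in range(iters):
            f(*args)
        times.append((time.perf_counter() - tic) / iters)
    return statistics.median(times) * 1e3
```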
Misc:
- NVIDIA Triton Inference Server will soon merge more direct and discoverable PyTorch + compile support (Add PyTorch platform handler triton-inference-server/python_backend#282), and they are even considering making torch.compile the default option (Add PyTorch platform handler triton-inference-server/python_backend#282 (comment)). This means that better telemetry and predictability around torch.compile will soon become more important: e.g. being able to completely save the best benchmarked kernels / selected cudnn algos to some file/database and then provide them as-is when deploying to new servers. This is important for forcing the same algos/impls at every startup (I assume the selected cudnn algos might currently change if, at algo-benchmarking time, someone on a shared server occupies the memory needed for benchmarking) and for giving some indication / hooks when that was not possible.
- A better compiler debug output visualization report is needed, similar to https://godbolt.org for C++ -> Assembly. Maybe an HTML report containing the Python source code, so that one can click on a compiled function and have the corresponding C++ / CUDA / Triton / Assembly shown for inspection?
```python
# bench_minmul_sum.py
import os
import sys
import time

enable_dynamic = '--enable-dynamic' in sys.argv

# env vars must be set before importing torch
if '--avx512' in sys.argv:
    os.environ['ATEN_CPU_CAPABILITY'] = 'avx512'
if '--verbose1' in sys.argv:
    os.environ['TORCH_COMPILE_DEBUG'] = '1'
if '--verbose2' in sys.argv:
    os.environ['TORCH_LOGS'] = 'output_code'

import torch
from torch._inductor import config

if '--simd512' in sys.argv:
    config.cpp.simdlen = 512

print('config.cpp.simdlen', config.cpp.simdlen)
print([line for line in torch.__config__.show().splitlines() if 'CPU capability' in line])
min_sum = lambda a, b: torch.min(a, b).sum(-1)
mul_sum = lambda a, b: torch.mul(a, b).sum(-1)
min_sum_compiled_dynamic = torch.compile(min_sum, dynamic = enable_dynamic)
mul_sum_compiled_dynamic = torch.compile(mul_sum, dynamic = enable_dynamic)
min_sum_compiled = torch.compile(min_sum)
mul_sum_compiled = torch.compile(mul_sum)
# 192 is 12 float32x16 (512-bit registers) or 24 float32x8 (256-bit registers)
static_shape = (192,)
dynamic_shape = (6, 183, 192)
def benchmark(name, f, a, b, K = 100):
    tic = time.time()
    for k in range(K):
        f(a, b)
    print(name, a.shape, b.shape, (time.time() - tic) * 1000, 'ms')
# warmup
for k in range(5):
    A = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(2)
    B = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(1)
    a = torch.rand(*static_shape, dtype = torch.float32)
    b = torch.rand(*static_shape, dtype = torch.float32)
    if enable_dynamic:
        min_sum_compiled_dynamic(A, B)
        mul_sum_compiled_dynamic(A, B)
        min_sum_compiled_dynamic(A, A)
        mul_sum_compiled_dynamic(A, A)
    min_sum_compiled(A, B)
    mul_sum_compiled(A, B)
    min_sum_compiled(A, A)
    mul_sum_compiled(A, A)
    min_sum(A, B)
    mul_sum(A, B)
    min_sum(A, A)
    mul_sum(A, A)
    if enable_dynamic:
        min_sum_compiled_dynamic(a, b)
        mul_sum_compiled_dynamic(a, b)
        min_sum_compiled_dynamic(a, a)
        mul_sum_compiled_dynamic(a, a)
    min_sum_compiled(a, b)
    mul_sum_compiled(a, b)
    min_sum_compiled(a, a)
    mul_sum_compiled(a, a)
    min_sum(a, b)
    mul_sum(a, b)
    min_sum(a, a)
    mul_sum(a, a)
# benchmark
A = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(2)
B = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(1)
a = torch.rand(*static_shape, dtype = torch.float32)
b = torch.rand(*static_shape, dtype = torch.float32)
if enable_dynamic:
    benchmark('min_sum_compiled_dynamic_AB', min_sum_compiled_dynamic, A, B)
    benchmark('mul_sum_compiled_dynamic_ab', mul_sum_compiled_dynamic, a, b)
    benchmark('min_sum_compiled_dynamic_AA', min_sum_compiled_dynamic, A, A)
    benchmark('mul_sum_compiled_dynamic_aa', mul_sum_compiled_dynamic, a, a)
benchmark('min_sum_compiled_AB', min_sum_compiled, A, B)
benchmark('mul_sum_compiled_ab', mul_sum_compiled, a, b)
benchmark('min_sum_compiled_AA', min_sum_compiled, A, A)
benchmark('mul_sum_compiled_aa', mul_sum_compiled, a, a)
benchmark('min_sum_AB', min_sum, A, B)
benchmark('mul_sum_ab', mul_sum, a, b)
benchmark('min_sum_AA', min_sum, A, A)
benchmark('mul_sum_aa', mul_sum, a, a)
```
Versions
2.1.0.dev20230802+cpu
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @wconstab