Case study of torch.compile / cpp inductor on CPU: min_sum / mul_sum with 1d / matmul-like with static / dynamic shapes #106614
@vadimkantorov


🐛 Describe the bug

(I'll add actual benchmarking details and logs and output_code.py in a bit)

I'm doing min_sum and mul_sum in two setups:

  1. (D, ) x (D, ) -> scalar
  2. (B, N, 1, D) x (B, 1, N, D) -> (B, N, N)

Case (1) is similar to computing some sort of norm of a given input or a distance between two inputs.
Case (2) is matmul-like and is similar to computing distances between all pairs of the batch (mm / pdist / cdist).
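
For concreteness, a minimal eager-mode sketch of the two setups (the full benchmark script is at the end of this issue):

# sketch: the two min_sum / mul_sum setups in eager mode
import torch

# case (1): (D,) x (D,) -> scalar
a, b = torch.rand(192), torch.rand(192)
s1 = torch.min(a, b).sum(-1)            # scalar
s2 = torch.mul(a, b).sum(-1)            # scalar, same value as torch.dot(a, b)

# case (2): (B, N, 1, D) x (B, 1, N, D) -> (B, N, N), all-pairs reduction over D
X = torch.rand(6, 183, 192)
A, B = X.unsqueeze(2), X.unsqueeze(1)   # (B, N, 1, D) and (B, 1, N, D)
d_min = torch.min(A, B).sum(-1)         # (B, N, N)
d_mul = torch.mul(A, B).sum(-1)         # (B, N, N), same as X @ X.transpose(-1, -2)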

When running the benchmark as python3 bench_minmul_sum.py --enable-dynamic --avx512 --verbose1 (full script at the end of this issue):

Findings and questions regarding dynamic shapes:

  1. Dynamic / static shape options are not confined to the torch.compile call (which is super-unintuitive and brittle). When using --enable-dynamic, all produced output_code.py files contain dynamic shapes, meaning dynamic shapes ended up being used not only for min_sum_compiled_dynamic but also for min_sum_compiled.
  2. When not using --enable-dynamic, all produced output_code.py files contain static shapes. Ideally this could enable full loop unrolling for small dims, but currently any unrolling can only be done by the underlying C++ compiler, since loop unrolling pragmas are missing from the generated C++.
  3. Dynamic-shape output_code.py files do not record divisibility and always contain an extra tail loop.
  4. Ideally, I would like static shapes for the 1d ops (the inner dim is static, or takes only a few possible values in my use case), and for the 2d ops static shapes for the inner dim with dynamic shapes for the batch dims (as can be done with torch.onnx.export); see the sketch after this list. Currently, parametrizing static vs dynamic shapes is unpredictable :( This is not very good :(
  5. Why is #pragma omp simd simdlen(8) used? Apparently it is correct, but if it is the register size in float32 lanes, I would expect simdlen(16), no? Also, I had a hard time finding the resulting assembly and checking whether the C++ compiler did loop unrolling (there was no #pragma omp unroll). How can one do that? Could the debug option also print the produced assembly (e.g. obtained via objdump)? Currently there are also too many logs, and it is hard to get through all of them.
  6. Better explanation of the tiling story to users is important, especially for matmul-like ops. These custom matmul-like ops are quite common in kernel methods and similar settings (see pykeops): max-sum operation #59545, [Feature Request] Compile compatible Neighborhood Algorithms for large Tensors #97006, Direct Implementation of K-Nearest neighbor (KNN) in pytorch #71386
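
Regarding point 4 above, a minimal sketch of one possible way to get per-dimension control, assuming torch._dynamo.mark_dynamic is available in the installed nightly and behaves as documented (the marked dims are treated as dynamic while unmarked dims stay static specializations):

# sketch: dynamic batch / N dims, static inner dim D (assumes torch._dynamo.mark_dynamic)
import torch
import torch._dynamo

mul_sum = lambda a, b: torch.mul(a, b).sum(-1)
mul_sum_compiled = torch.compile(mul_sum)  # no dynamic= argument: let per-tensor marks decide

A = torch.rand(6, 183, 1, 192)
B = torch.rand(6, 1, 183, 192)
# mark only the batch / N dims as dynamic; the inner dim D = 192 stays static
torch._dynamo.mark_dynamic(A, 0); torch._dynamo.mark_dynamic(A, 1)
torch._dynamo.mark_dynamic(B, 0); torch._dynamo.mark_dynamic(B, 2)
mul_sum_compiled(A, B)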

Findings and questions regarding NaNs:

  1. The code produced for min_sum contains NaN handling.
  2. It would be good to be able to assert or hint to the compiler that the inputs will not contain NaNs, so that NaN-compliant handling is not needed. This NaN handling might be a perf hit for small 1d vectors; one possible workaround is sketched below.
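
Regarding point 2, one possible workaround sketch (not an official "assume no NaNs" hint, which as far as I know does not exist): express the elementwise min through a comparison, which carries no NaN-propagation semantics and should lower to a plain compare-and-select. This is only valid if the inputs are known to be NaN-free.

# sketch: min without NaN propagation (only safe for NaN-free inputs)
import torch
min_sum_no_nan = lambda a, b: torch.where(a < b, a, b).sum(-1)
min_sum_no_nan_compiled = torch.compile(min_sum_no_nan)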

Findings regarding mul_sum (which is equivalent to matrix product or dot product):

  1. The mul_sum pattern is not recognized, and no gemm or dot call is produced. It also seems that no tiling is done, despite the fact that a sum-reduction is used; a manual rewrite to matmul is sketched below.
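
For reference, the same reduction can be written today so that it does dispatch to a gemm; this is a manual rewrite sketch, not something inductor does automatically as of this report:

# sketch: mul_sum over (B, N, 1, D) x (B, 1, N, D) is just a batched matmul
import torch

X = torch.rand(6, 183, 192)
out_pointwise = torch.mul(X.unsqueeze(2), X.unsqueeze(1)).sum(-1)  # (B, N, N), no gemm
out_matmul = torch.matmul(X, X.transpose(-1, -2))                  # (B, N, N), uses gemm
assert torch.allclose(out_pointwise, out_matmul, rtol=1e-4, atol=1e-4)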

Findings regarding the benchmarking:

  1. On some platforms (e.g. Linux under Windows WSLv1) it can be hard to account for CPU throttling, and on laptops running on battery the throttling is often quite severe. It would be good if PyTorch had recommendations on how to account for this properly; a sketch of a more robust timing loop is below.
  2. config.cpp.simdlen is 512 even though, without ATEN_CPU_CAPABILITY=avx512, torch.__config__.show() reports "CPU capability usage: AVX2" (which has 256-bit registers only). It also seems that this 256-bit ATEN_CPU_CAPABILITY is discovered automatically, even though the laptop supports avx512.
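
Regarding point 1, a sketch of a less fragile timing loop using torch.utils.benchmark (it handles warmup and reports per-iteration statistics, though it cannot control CPU frequency scaling by itself):

# sketch: timing with torch.utils.benchmark instead of a bare time.time() loop
import torch
import torch.utils.benchmark as benchmark

a, b = torch.rand(192), torch.rand(192)
f = torch.compile(lambda x, y: torch.min(x, y).sum(-1))
f(a, b)  # compile outside the timed region
timer = benchmark.Timer(stmt='f(a, b)', globals={'f': f, 'a': a, 'b': b})
print(timer.blocked_autorange(min_run_time=1.0))  # median / IQR over many runs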

Misc:

  1. NVIDIA Triton Inference Server will soon merge more direct and discoverable PyTorch + compile support: Add PyTorch platform handler triton-inference-server/python_backend#282; they are even considering making torch.compile the default option: Add PyTorch platform handler triton-inference-server/python_backend#282 (comment). This means that better telemetry and predictability around torch.compile will soon become more important, e.g. being able to save the best benchmarked kernels / selected cudnn algos to some file/database and then providing them as-is when deploying to new servers. This matters for forcing the same algos/impls at every startup (I assume the selected cudnn algos can currently change if, at algo-benchmarking time, someone on a shared server occupies the memory needed for benchmarking) and for getting some indication / hooks when that was not possible.
  2. A better compiler debug output visualization report is needed, similar to https://godbolt.org for C++ -> assembly. Maybe an HTML report containing the Python source code, so that one can click on a compiled function and inspect the corresponding C++ / CUDA / Triton / assembly? A manual workaround is sketched below.
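
Regarding point 2, until such a report exists, a sketch of manually locating and disassembling the compiled kernels; this assumes the default inductor cache location /tmp/torchinductor_<user>, which is not a stable interface:

# sketch: disassemble cached inductor kernel .so files with objdump (assumes default cache dir)
import getpass, glob, subprocess

cache_dir = f'/tmp/torchinductor_{getpass.getuser()}'
for so in glob.glob(f'{cache_dir}/**/*.so', recursive=True):
    asm = subprocess.run(['objdump', '-d', so], capture_output=True, text=True).stdout
    print('====', so)
    print(asm[:2000])  # print only the beginning; full dumps are long
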
# bench_minmul_sum.py

import os
import sys
import time

enable_dynamic = '--enable-dynamic' in sys.argv

if '--avx512' in sys.argv:
    os.environ['ATEN_CPU_CAPABILITY'] = 'avx512'
if '--verbose1' in sys.argv:
    os.environ['TORCH_COMPILE_DEBUG'] = '1'
if '--verbose2' in sys.argv:
    os.environ['TORCH_LOGS'] = 'output_code'
if '--simd512' in sys.argv:
    from torch._inductor import config
    config.cpp.simdlen = 512

import torch  # env vars above must be set before torch is imported
from torch._inductor import config

print('config.cpp.simdlen', config.cpp.simdlen)
print([line for line in torch.__config__.show().splitlines() if 'CPU capability' in line])

min_sum = lambda a, b: torch.min(a, b).sum(-1)
mul_sum = lambda a, b: torch.mul(a, b).sum(-1)

min_sum_compiled_dynamic = torch.compile(min_sum, dynamic = enable_dynamic)
mul_sum_compiled_dynamic = torch.compile(mul_sum, dynamic = enable_dynamic)
min_sum_compiled = torch.compile(min_sum)
mul_sum_compiled = torch.compile(mul_sum)

# 192 is 12 float32x16 (512-bit registers) or 24 float32x8 (256-bit registers)
static_shape = (192,)
dynamic_shape = (6, 183, 192)

def benchmark(name, f, a, b, K = 100):  # prints total wall time in ms over K calls
    tic = time.time()
    for k in range(K):
        f(a, b)
    print(name, a.shape, b.shape, (time.time() - tic) * 1000, 'ms')

# warmup
for k in range(5):
    A = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(2)
    B = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(1)
    a = torch.rand(*static_shape, dtype = torch.float32)
    b = torch.rand(*static_shape, dtype = torch.float32)

    if enable_dynamic:
        min_sum_compiled_dynamic(A, B)
        mul_sum_compiled_dynamic(A, B)
        min_sum_compiled_dynamic(A, A)
        mul_sum_compiled_dynamic(A, A)
    min_sum_compiled(A, B)
    mul_sum_compiled(A, B)
    min_sum_compiled(A, A)
    mul_sum_compiled(A, A)
    min_sum(A, B)
    mul_sum(A, B)
    min_sum(A, A)
    mul_sum(A, A)

    if enable_dynamic:
        min_sum_compiled_dynamic(a, b)
        mul_sum_compiled_dynamic(a, b)
        min_sum_compiled_dynamic(a, a)
        mul_sum_compiled_dynamic(a, a)
    min_sum_compiled(a, b)
    mul_sum_compiled(a, b)
    min_sum_compiled(a, a)
    mul_sum_compiled(a, a)
    min_sum(a, b)
    mul_sum(a, b)
    min_sum(a, a)
    mul_sum(a, a)

# benchmark
A = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(2)
B = torch.rand(*dynamic_shape, dtype = torch.float32).unsqueeze(1)
a = torch.rand(*static_shape, dtype = torch.float32)
b = torch.rand(*static_shape, dtype = torch.float32)

if enable_dynamic:
    benchmark('min_sum_compiled_dynamic_AB', min_sum_compiled_dynamic, A, B)
    benchmark('mul_sum_compiled_dynamic_ab', mul_sum_compiled_dynamic, a, b)
    benchmark('min_sum_compiled_dynamic_AA', min_sum_compiled_dynamic, A, A)
    benchmark('mul_sum_compiled_dynamic_aa', mul_sum_compiled_dynamic, a, a)
benchmark('min_sum_compiled_AB', min_sum_compiled, A, B)
benchmark('mul_sum_compiled_ab', mul_sum_compiled, a, b)
benchmark('min_sum_compiled_AA', min_sum_compiled, A, A)
benchmark('mul_sum_compiled_aa', mul_sum_compiled, a, a)
benchmark('min_sum_AB', min_sum, A, B)
benchmark('mul_sum_ab', mul_sum, a, b)
benchmark('min_sum_AA', min_sum, A, A)
benchmark('mul_sum_aa', mul_sum, a, a)

Versions

2.1.0.dev20230802+cpu

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @wconstab


    Labels

    enhancement, module: custom-operators, module: dynamic shapes, module: inductor, oncall: pt2, triaged
