Add torch.profile benchmarking function to feedback_fns by exclamaforte · Pull Request #153579 · pytorch/pytorch

Closed · wants to merge 1 commit
4 changes: 4 additions & 0 deletions torch/_inductor/autotune_process.py
@@ -133,6 +133,10 @@ def start(self):
"TORCH_WARM_POOL": "0",
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": get_ld_library_path(),
# This will cause the subprocs to profile using the profiler.
"TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING": "1"
if config.profile_bandwidth_with_do_bench_using_profiling
else "0",
}
if self.device is not None:
extra_env[CUDA_VISIBLE_DEVICES] = str(self.device)
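The new environment variable is how the flag reaches the autotuning subprocesses: they inherit TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING at startup and torch._inductor.config picks it up when it is imported there. Below is a rough, hypothetical sketch of that env-to-config pattern; the exact default and parsing in config.py are assumptions, not copied from the source.

import os

# Hypothetical reconstruction of how the TORCHINDUCTOR_-prefixed variable would be
# read at import time in torch/_inductor/config.py; the real definition may differ.
profile_bandwidth_with_do_bench_using_profiling = (
    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING", "0") == "1"
)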
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/cpp_template_kernel.py
@@ -7,6 +7,7 @@
from sympy.parsing.sympy_parser import parse_expr

import torch
from torch._inductor.utils import do_bench_using_profiling
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import SymT

@@ -566,6 +567,9 @@ def precompile(self) -> None:

def benchmark(self, *args, out) -> float:
assert self.bmreq is not None
if config.profile_bandwidth_with_do_bench_using_profiling:
algo = self.bmreq.make_run_fn(*args, out=out)
return do_bench_using_profiling(algo)
return self.bmreq.benchmark(*args, out=out)

def hash_key(self) -> str:
10 changes: 6 additions & 4 deletions torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -8,10 +8,11 @@

from sympy import Expr, symbols

import torch._inductor.config as config
from torch import dtype as torch_dtype
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import Placeholder
from torch._inductor.utils import do_bench_using_profiling, Placeholder
from torch.utils._sympy.value_ranges import ValueRanges

from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
@@ -593,9 +594,10 @@ def precompile(self) -> None:

def benchmark(self, *args, out) -> float:
assert self.bmreq is not None
return self.bmreq.benchmark(
*args, out=out
) # @TODO: Hack for ensuring that Cutlass Kernel is preferred
if config.profile_bandwidth_with_do_bench_using_profiling:
algo = self.bmreq.make_run_fn(*args, out=out)
return do_bench_using_profiling(algo)
return self.bmreq.benchmark(*args, out=out)

def __str__(self) -> str:
return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
5 changes: 5 additions & 0 deletions torch/_inductor/codegen/rocm/rocm_kernel.py
@@ -3,7 +3,9 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch._inductor.config as config
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.utils import do_bench_using_profiling

from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox
from ...virtualized import V
@@ -247,6 +249,9 @@ def precompile(self) -> None:

def benchmark(self, *args, out) -> float:
assert self.bmreq is not None
if config.profile_bandwidth_with_do_bench_using_profiling:
algo = self.bmreq.make_run_fn(*args, out=out)
return do_bench_using_profiling(algo)
return self.bmreq.benchmark(*args, out=out)

def __str__(self) -> str:
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/subgraph.py
@@ -3,6 +3,7 @@
from typing import Any, Callable

import torch
import torch._inductor.config as config
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
@@ -14,6 +15,7 @@
Layout,
)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import do_bench_using_profiling
from torch._inductor.virtualized import V


@@ -113,6 +115,8 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
bm_func = mod.call

bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

def hash_key(self) -> str:
3 changes: 3 additions & 0 deletions torch/_inductor/ir.py
@@ -88,6 +88,7 @@
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
do_bench_using_profiling,
get_dtype_size,
get_kernel_metadata,
GPU_ALIGN_BYTES,
@@ -4697,6 +4698,8 @@ def __init__(

def benchmark(self, *args, out) -> float: # type: ignore[no-untyped-def]
algo = self.to_callable()
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: algo(*args))
return benchmarker.benchmark(algo, args, {"out": out})

def call_name(self) -> str:
61 changes: 48 additions & 13 deletions torch/_inductor/select_algorithm.py
@@ -69,6 +69,7 @@
from .runtime.triton_heuristics import FixedGrid
from .utils import (
ceildiv,
do_bench_using_profiling,
FakeIndentedBuffer,
get_dtype_size,
is_gpu,
@@ -1777,6 +1778,9 @@ def __init__(

def benchmark(self, *args, out):
assert self.bmreq is not None
if config.profile_bandwidth_with_do_bench_using_profiling:
algo = self.bmreq.make_run_fn(*args, out=out)
return do_bench_using_profiling(algo)
return self.bmreq.benchmark(*args, out=out)

def precompile(self):
@@ -1860,6 +1864,8 @@ def benchmark(self, *args, out):
out_new, tuple(out.size()), tuple(out.stride())
)
out.copy_(out_new) # for correctness checking
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: algo(*args))
return benchmarker.benchmark(algo, args, {})

def to_callable(self):
@@ -2065,6 +2071,24 @@ def create_precompile_key(
)


# Args to FeedbackFunctions
# timings: mapping from choices to the benchmark time
# name: name of the op
# input_nodes: list of input ir.py Nodes
# choices: list of choices
# profiled time: Callable that returns a dict mapping from choices to the profiled time
FeedbackFunction = Callable[
[
dict[ChoiceCaller, float],
str,
list[Any],
list[ChoiceCaller],
Callable[[], dict[ChoiceCaller, float]],
],
None,
]


class AlgorithmSelectorCache(PersistentCache):
"""
A persistent cache for algorithm selection results used in autotuning of GEMMs
@@ -2085,11 +2109,7 @@ def __init__(self, *args, **kwargs) -> None:
# of a particular key
self.precompile_cache: dict[str, Callable[[], None]] = {}
# list of callbacks that are called after benchmarking
self.feedback_saver_fns: list[
Callable[
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
]
] = []
self.feedback_saver_fns: list[FeedbackFunction] = []

clear_on_fresh_inductor_cache(self)

@@ -2236,8 +2256,28 @@ def do_autotuning(choices, precompile_fn):
name, input_nodes, timings, autotune_elapse, precompile_elapse
)

def profiler_bench_function():
# we're not running through the normal caching autotuner method here because we want to avoid returning
# the cached value.
# Avoid benchmarking in a separate process because it's not easy to signal to the TuningProcess that we
# should use the profiler.
with config.patch(
profile_bandwidth_with_do_bench_using_profiling=True,
Contributor:
Sorry, this part I can't quite figure out. For profile_bandwidth_with_do_bench_using_profiling=True, I don't understand why you need to check this config in many of the changes above. This looks like an existing config. If you're overriding the config here, then isn't there some underlying do_bench function that's already being configured to use profiling? Why do you need, e.g.,

        if config.profile_bandwidth_with_do_bench_using_profiling:
            algo = self.bmreq.make_run_fn(*args, out=out)
            return do_bench_using_profiling(algo)

For autotune_in_subproc=False, it seems a bit weird to me to quietly not use the subproc-based method if the user requested it. Seems like we should assert that this feature is just not compatible with do_bench_using_profiling?

Contributor Author:
So profile_bandwidth_with_do_bench_using_profiling basically tells the CachingAutotuner to use the profiler. The AlgorithmSelectorCache doesn't use the CachingAutotuner; it calls the benchmark method of ChoiceCallers, which is why it's necessary to change this in many spots. The do_bench function is InductorBenchmarker/TritonBenchmarker.bench in this case, so we're replacing that with do_bench_using_profiling.

> For autotune_in_subproc=False, it seems a bit weird to me to quietly not use the subproc-based method if the user requested it.

I can support autotune_in_subproc, but I would need to signal to the TuningProcess which benchmarker it should use, because the subprocs are all created at startup, so just setting the env var doesn't work. As it stands, either all of them do profiler benchmarking or none of them do. I can make this refactor, but I don't think it's worth complicating TuningProcess. I'd rather just run them in the single process, since this change is mainly focused on data collection/logging.
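For context on what the profiler-based path measures: a helper like do_bench_using_profiling times a callable under torch.profiler and reduces the recorded device-side kernel time to a per-call figure, rather than relying on wall-clock timing. The snippet below is a minimal, hypothetical sketch of that idea, not the implementation in torch/_inductor/utils.py; the warmup/repeat counts and the exact event attributes are assumptions.

import torch
from torch.profiler import profile, ProfilerActivity

def bench_with_profiler(fn, warmup: int = 5, rep: int = 20) -> float:
    """Estimate per-call GPU time of `fn` in milliseconds using the profiler."""
    for _ in range(warmup):
        fn()  # warm up so one-time costs don't skew the measurement
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(rep):
            fn()
        torch.cuda.synchronize()
    # key_averages() reports microseconds; older releases name this self_cuda_time_total.
    total_us = sum(evt.self_device_time_total for evt in prof.key_averages())
    return total_us / rep / 1e3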

Contributor:
> I can support autotune_in_subproc

Right. I wasn't proposing to support it. Isn't the new feature you're adding optional, or did I misread and it's active by default? I was just saying that we could fail if the user wants both profiling and subprocs rather than quietly doing something else.

Contributor Author:
I made the subprocs respect the flag; it should work fine, but we can't change it later on.

Contributor:
Does that mean you can remove autotune_in_subproc=False here?

Contributor Author:
No, because for this logging code we need it to switch between the profiler and non-profiler paths, which only works if it's not using the subprocs.

Contributor Author:
Which is why the code is passed in as an unevaluated function. It might run slower, obviously, but the logging code can run the function if it wants the information, and it doesn't get run otherwise.

autotune_in_subproc=False,
):
return self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns
)(choices)

for feedback_fn in self.feedback_saver_fns:
feedback_fn(timings, name, input_nodes, choices)
# re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk.
feedback_fn(
timings,
name,
input_nodes,
choices,
profiler_bench_function,
)

return timings

@@ -2917,12 +2957,7 @@ def key_of(node):
),
)

def add_feedback_saver(
self,
fn: Callable[
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
],
):
def add_feedback_saver(self, fn: FeedbackFunction):
self.feedback_saver_fns.append(fn)


@@ -2946,7 +2981,7 @@ def autotune_select_algorithm(*args, **kwargs):


def add_feedback_saver(
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
fn: FeedbackFunction,
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
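Putting the new signature together: a feedback saver now receives the benchmarked timings plus a thunk that re-runs the same choices under the profiler, and it is registered through add_feedback_saver. The sketch below is an illustrative consumer only; the printed format, the units of the timing values, and the filtering on op name are hypothetical and not part of the PR.

from torch._inductor.select_algorithm import add_feedback_saver

def log_autotune_feedback(timings, name, input_nodes, choices, profiled_time_fn):
    # timings: dict mapping each ChoiceCaller to its benchmarked time.
    # profiled_time_fn: thunk returning a dict of profiler-measured times for the
    # same choices; it re-benchmarks, so only call it when the data is worth the cost.
    if not timings:
        return
    best = min(timings, key=timings.__getitem__)
    print(f"[autotune] {name}: best of {len(choices)} choices is {best.name} ({timings[best]:.4f})")
    if name in ("mm", "addmm"):  # hypothetical filter: only pay for profiling on matmuls
        for choice, t in profiled_time_fn().items():
            print(f"    profiler time for {choice.name}: {t:.4f}")

add_feedback_saver(log_autotune_feedback)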