[WIP] add envvar to bisect number of graphs compiled by eellison · Pull Request #153275 · pytorch/pytorch
Closed · wants to merge 4 commits
31 changes: 31 additions & 0 deletions test/dynamo/test_compiler_bisector.py
@@ -10,6 +10,7 @@
from torch._inductor import config
from torch._inductor.compiler_bisector import CompilerBisector
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.library import _scoped_library, Library
from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -149,6 +150,36 @@ def test_fn():
        self.assertEqual(out.subsystem, "inductor_fallback_random")
        self.assertTrue("inductor_fallback_random" in out.debug_info)

    # TODO: incorporate into the compiler bisector
    @torch._dynamo.config.patch(debug_max_graphs=1)
    def test_bisecting_num_graphs(self):
        from torch._dynamo.utils import counters

        def foo(x):
            out = x + 3
            torch._dynamo.graph_break()
            return out * 2

        out = torch.compile(foo)(torch.ones([4], device="cuda"))
        self.assertEqual(out, foo(torch.ones([4], device="cuda")))
        self.assertEqual(counters["aot_autograd"]["total"], 1)

    # TODO: incorporate into the compiler bisector
    @torch._dynamo.config.patch(debug_max_backend_graphs=1)
    def test_bisecting_backend_graphs(self):
        from torch._dynamo.utils import counters

        def foo(x):
            out = x + 3
            torch._dynamo.graph_break()
            return out * 2

        inp = torch.ones([4], device="cuda")
        out, code = run_and_get_code(torch.compile(foo), inp)
        self.assertEqual(len(code), 1)
        self.assertEqual(counters["aot_autograd"]["total"], 2)
        self.assertEqual(out, foo(torch.ones([4], device="cuda")))

    def test_crossref(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
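For reference, the same limit should be reachable outside the test suite through the environment variable rather than the config patch. A minimal standalone sketch (not part of the PR), assuming `TORCH_BISECT_MAX_GRAPHS` is read when `torch._dynamo.config` is first imported and that a CUDA device is available:

```python
# Hypothetical repro mirroring test_bisecting_num_graphs, using the environment
# variable instead of @torch._dynamo.config.patch. The variable must be set
# before torch is imported, since config.py reads os.environ at import time.
import os

os.environ["TORCH_BISECT_MAX_GRAPHS"] = "1"

import torch


def foo(x):
    out = x + 3
    torch._dynamo.graph_break()  # splits foo into two dynamo graphs
    return out * 2


x = torch.ones([4], device="cuda")
out = torch.compile(foo)(x)

# With the limit at 1, the frame after the graph break is skipped (SkipFrame)
# and runs eagerly, so the result should still match eager execution.
assert torch.equal(out, foo(x))
```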
32 changes: 32 additions & 0 deletions torch/__init__.py
@@ -2289,6 +2289,32 @@ def compiled_with_cxx11_abi() -> builtins.bool:
from torch.utils.dlpack import from_dlpack, to_dlpack


def check_max_graphs() -> _Optional[builtins.bool]:
    """
    If we have hit a user-specified maximum number of graphs, skip this frame.

    Otherwise, return whether we have hit the maximum number of graphs for the
    given backend, in which case the caller falls back to aot_eager.
    """
    from torch._dynamo.utils import GraphsCompiledState

    max_compiled_graphs = torch._dynamo.config.debug_max_graphs
    max_backend_graphs = torch._dynamo.config.debug_max_backend_graphs
    if max_compiled_graphs is None and max_backend_graphs is None:
        return None

    GraphsCompiledState.increment()
    num_graphs = GraphsCompiledState.get_num_graphs()
    if max_compiled_graphs is not None and num_graphs > builtins.int(
        max_compiled_graphs
    ):
        raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_compiled_graphs}")

    return max_backend_graphs is not None and num_graphs > builtins.int(
        max_backend_graphs
    )


class _TorchCompileInductorWrapper:
    compiler_name = "inductor"

@@ -2361,6 +2387,11 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
    def __call__(self, model_, inputs_):
        from torch._inductor.compile_fx import compile_fx

        if check_max_graphs():
            # Backend-graph limit hit: compile the remaining graphs with
            # aot_eager instead of inductor.
            return _TorchCompileWrapper(
                "aot_eager", "default", {}, self.dynamic
            ).__call__(model_, inputs_)

        return compile_fx(model_, inputs_, config_patches=self.config)

    def get_compiler_config(self):
@@ -2406,6 +2437,7 @@ def __eq__(self, other):
)

    def __call__(self, model_, inputs_):
        # May raise SkipFrame once the max-graph limit is hit; the
        # backend-graph return value is unused for non-inductor backends.
        check_max_graphs()
        return self.compiler_fn(model_, inputs_, **self.kwargs)

    def reset(self):
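Because the backend-graph limit routes the remaining graphs to aot_eager, it lends itself to a simple bisection loop. A sketch of one possible driver (not part of this PR; `repro.py` and the upper bound of 64 graphs are placeholders for a user-supplied failing script and a known graph count):

```python
# Hypothetical bisection driver over TORCH_BISECT_MAX_BACKEND_GRAPHS: find the
# smallest number of inductor-compiled graphs that still reproduces a failure,
# with the remaining graphs compiled via aot_eager by the wrapper above.
import os
import subprocess


def fails_with_limit(limit: int) -> bool:
    # repro.py is assumed to exit non-zero when the failure reproduces.
    env = dict(os.environ, TORCH_BISECT_MAX_BACKEND_GRAPHS=str(limit))
    return subprocess.run(["python", "repro.py"], env=env).returncode != 0


lo, hi = 0, 64  # 64: assumed upper bound on the number of graphs
while lo < hi:
    mid = (lo + hi) // 2
    if fails_with_limit(mid):
        hi = mid  # failure still reproduces with only `mid` inductor graphs
    else:
        lo = mid + 1  # need more inductor-compiled graphs to hit the failure
print(f"first suspect inductor-compiled graph: {lo}")
```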
9 changes: 8 additions & 1 deletion torch/_dynamo/__init__.py
@@ -48,7 +48,13 @@
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state
from .symbolic_convert import TensorifyState
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
from .utils import (
    graph_break_reasons,
    GraphsCompiledState,
    guard_failures,
    orig_code_map,
    reset_frame_count,
)


# Register polyfill functions
@@ -131,6 +137,7 @@ def reset() -> None:
    callback_handler.clear()
    GenerationTracker.clear()
    TensorifyState.clear()
    GraphsCompiledState.clear()
    torch._dynamo.utils.warn_once_cache.clear()
    torch._dynamo.utils.user_obj_id_to_weakref.clear()
    torch._C._autograd._saved_tensors_hooks_set_tracing(False)
8 changes: 8 additions & 0 deletions torch/_dynamo/config.py
@@ -243,6 +243,14 @@
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True

# maximum number of dynamo graphs to compile.
# if we exceed this limit, we will raise a SkipFrame
debug_max_graphs = os.environ.get("TORCH_BISECT_MAX_GRAPHS", None)

# maximum number of dynamo graphs to invoke with the compiled backend.
# if we exceed this limit, we will defer to aot_eager
debug_max_backend_graphs = os.environ.get("TORCH_BISECT_MAX_BACKEND_GRAPHS", None)
Contributor:
I'll wait for a second reviewer to agree on bikeshedding, but a few thoughts:

  • On having two configs: separate configs for "run the first X graphs through dynamo and fall the rest back to eager", vs. "of the graphs that go through dynamo, run the first Y through inductor and have the rest go through a simple backend" seems ok. I guess I'm not convinced that it's necessary, but having both knobs available for debugging seems useful

  • for the second config, at first I was surprised that you hardcoded the "fallback" backend to aot_eager, but if we're thinking about it as an extra option mainly to find bugs in inductor then it does sound relatively reasonable? We could have yet another config to specify the default backend but that feels like overkill. Also aot_eager_decomp_partition seems to have a few cases where it does not give the same numerics as eager due to our decomps being bad, so aot_eager seems like the safer default.

  • if this backend is mainly for "debug inductor problems", wdyt of debug_max_inductor_graphs?

Contributor Author:
@bdhirsh, sorry, I reopened since this had internal commits: #154543.

Mind commenting there?


# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
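To make the distinction between the two knobs concrete, an illustrative sketch of their intended semantics, assuming the variables are set before `torch._dynamo.config` is imported (the config stores the raw strings; `check_max_graphs()` converts them with `int()` when comparing):

```python
# Illustrative only: how the two environment variables are expected to differ.
import os

# Frames beyond this graph count are skipped entirely (SkipFrame) and run eagerly.
os.environ["TORCH_BISECT_MAX_GRAPHS"] = "8"

# Graphs beyond this count still go through dynamo/AOTAutograd, but are
# compiled with the aot_eager backend instead of inductor.
os.environ["TORCH_BISECT_MAX_BACKEND_GRAPHS"] = "4"

import torch._dynamo.config as dynamo_config

# The config holds the raw environment strings.
assert dynamo_config.debug_max_graphs == "8"
assert dynamo_config.debug_max_backend_graphs == "4"
```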
20 changes: 20 additions & 0 deletions torch/_dynamo/utils.py
@@ -4544,6 +4544,26 @@ def record(cls):
cls.end()


class GraphsCompiledState:
    """
    Tracks number of compiled graphs.
    """

    num_graphs: int = 0

    @classmethod
    def clear(cls) -> None:
        cls.num_graphs = 0

    @classmethod
    def increment(cls) -> None:
        cls.num_graphs += 1

    @classmethod
    def get_num_graphs(cls) -> int:
        return cls.num_graphs


def set_feature_use(feature: str, usage: bool):
    """
    Records whether we are using a feature
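The counter added above is process-global class state; a short sketch of its expected lifecycle, given that `torch._dynamo.reset()` now clears it alongside the other per-process caches:

```python
# Minimal sketch of the counter lifecycle, based on the clear() hook added to
# torch._dynamo.reset() in this PR.
import torch
from torch._dynamo.utils import GraphsCompiledState

GraphsCompiledState.increment()
GraphsCompiledState.increment()
assert GraphsCompiledState.get_num_graphs() == 2

torch._dynamo.reset()  # reset() now also clears the graph counter
assert GraphsCompiledState.get_num_graphs() == 0
```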