Fix evaluate_expr to include suppress_guards_tls in cache key by aorenste · Pull Request #152661 · pytorch/pytorch · GitHub

Fix evaluate_expr to include suppress_guards_tls in cache key #152661


Closed
aorenste wants to merge 7 commits
18 changes: 15 additions & 3 deletions torch/fx/experimental/recording.py
@@ -214,6 +214,10 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
#
# name: the name of the function being recorded. Normally (and by default) this
# is taken from the decorated function but can be set if you need to override
# it.
#
# When to save the list of TrackedFake?
# =====================================
# We should save the list of TrackedFake whenever the translation validation
@@ -225,15 +229,19 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
# At the moment, there are 2 methods that save the list:
# - ShapeEnv.evaluate_expr
# - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
def record_shapeenv_event(
*, save_tracked_fakes: bool = False, name: Optional[str] = None
) -> Callable:
def decorator(fn: Callable) -> Callable:
assert callable(fn)
args = inspect.getfullargspec(fn).args
assert args and args[0] == "self", (
"record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
"code so that it calls into a method on ShapeEnv"
)
name = fn.__name__
nonlocal name
if name is None:
name = fn.__name__

@functools.wraps(fn)
def wrapper(*args, **kwargs):
@@ -281,7 +289,11 @@ def retlog(r):
)
# Record the event for 'fn'.
event = ShapeEnvEvent(
fn, list(args), kwargs, tracked_fakes, name=fn.__name__
fn,
list(args),
kwargs,
tracked_fakes,
name=name,
)
# Play the event on this ShapeEnv.
# NB: It's important to put the event first, because running
30 changes: 23 additions & 7 deletions torch/fx/experimental/symbolic_shapes.py
@@ -3936,7 +3936,8 @@ def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
node.meta[CURRENT_NODE_KEY] = get_current_node()

def _suppress_guards_tls(self) -> bool:
@staticmethod
def _suppress_guards_tls() -> bool:
return getattr(TLS, "suppress_guards", False)

@record_shapeenv_event()
@@ -6811,8 +6812,6 @@ def evaluate_sym_node(
sym_node.expr, sym_node.hint, sym_node.fx_node, size_oblivious
)

@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(
self,
orig_expr: sympy.Basic,
@@ -6821,6 +6820,27 @@
size_oblivious: bool = False,
*,
forcing_spec: bool = False,
) -> sympy.Basic:
"""
Given an expression, evaluates it, adding guards if necessary
"""

# Add extra state that evaluate_expr() depends on.
suppress_guards_tls = ShapeEnv._suppress_guards_tls()
laithsakka (Contributor) commented on May 5, 2025:
So the outermost API does take the shape env as input; I wonder if we should have just made this an instance field instead of a global state field (suppress_guards and suppress_guards_stack):

@contextmanager
def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]:
    shape_env._suppress_guards_enter()
    try:
        yield
    finally:
        shape_env._suppress_guards_exit()
        
    @record_shapeenv_event()
    def _suppress_guards_enter(self) -> None:
        if not hasattr(TLS, "suppress_guards_stack"):
            TLS.suppress_guards_stack = []
        old = self._suppress_guards_tls()
        TLS.suppress_guards_stack.append(old)
        TLS.suppress_guards = True

    @record_shapeenv_event()
    def _suppress_guards_exit(self) -> None:
        old = (
            TLS.suppress_guards_stack.pop()
            if len(TLS.suppress_guards_stack) > 0
            else False
        )
        TLS.suppress_guards = old
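
For context, a rough sketch of the instance-field alternative being suggested (an illustration only, not code from this PR; ShapeEnvSketch and its members are hypothetical names): the suppression flag and its stack would live on the ShapeEnv object rather than in the thread-local TLS global.

    # Hypothetical sketch of the reviewer's suggestion; not PyTorch code.
    from contextlib import contextmanager
    from typing import Iterator, List


    class ShapeEnvSketch:
        def __init__(self) -> None:
            self.suppress_guards: bool = False
            self.suppress_guards_stack: List[bool] = []

        def _suppress_guards_enter(self) -> None:
            # Save the previous value so nested suppression restores correctly.
            self.suppress_guards_stack.append(self.suppress_guards)
            self.suppress_guards = True

        def _suppress_guards_exit(self) -> None:
            self.suppress_guards = (
                self.suppress_guards_stack.pop()
                if self.suppress_guards_stack
                else False
            )


    @contextmanager
    def suppress_guards(env: ShapeEnvSketch) -> Iterator[None]:
        env._suppress_guards_enter()
        try:
            yield
        finally:
            env._suppress_guards_exit()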

Contributor commented:

This is used only in one place:

        # Needed to make sure we don't accidentally specialize any symbols
        assert self.fake_tensor_mode.shape_env is not None
        env = self.fake_tensor_mode.shape_env
        self.stack.enter_context(
            torch.fx.experimental.symbolic_shapes._suppress_guards(env)
        )
        return (
            str(CompileContext.current_compile_id()),
            inputs,
            sizes,
            scalars,
        )

aorenste (Contributor, Author) replied:
Making it an instance var is an orthogonal issue - you'd still need to tell the cache that the variable should participate in the cache key (lru_cache doesn't automatically include instance vars in the cache key)
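
To illustrate the caching point with a minimal standalone sketch (not PyTorch code; Env, evaluate, and evaluate_fixed are hypothetical names): functools.lru_cache keys only on the arguments of the call, so state read inside the function never participates in cache lookups unless it is passed in explicitly.

    import functools
    import threading

    TLS = threading.local()


    class Env:
        @functools.lru_cache(256)
        def evaluate(self, expr):
            # Buggy: the result depends on hidden thread-local state that is
            # not part of the cache key.
            return (expr, getattr(TLS, "suppress_guards", False))

        @functools.lru_cache(256)
        def evaluate_fixed(self, expr, _suppress_guards_tls):
            # Fixed: the caller passes the state in, so it is part of the key.
            return (expr, _suppress_guards_tls)


    env = Env()
    print(env.evaluate("s0 == 2"))   # ('s0 == 2', False)
    TLS.suppress_guards = True
    print(env.evaluate("s0 == 2"))   # stale cache hit: still ('s0 == 2', False)
    print(env.evaluate_fixed("s0 == 2", getattr(TLS, "suppress_guards", False)))  # ('s0 == 2', True)

This mirrors the change in this PR: evaluate_expr reads the TLS flag once and forwards it to the lru_cache-wrapped _inner_evaluate_expr as an explicit argument.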

return self._inner_evaluate_expr(
orig_expr, hint, fx_node, size_oblivious, forcing_spec, suppress_guards_tls
)

@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr")
laithsakka (Contributor) commented on May 5, 2025:

I mean, is it bad if it stayed _inner_evaluate_expr? What's the motivation for the renaming?
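
For background, a standalone sketch of what the name= override changes (hypothetical names, not the actual recording machinery): without it, events recorded for the cached inner function would carry the name _inner_evaluate_expr rather than the public evaluate_expr.

    import functools
    from typing import Callable, Optional

    recorded_events = []


    def record_event(*, name: Optional[str] = None) -> Callable:
        # Simplified stand-in for record_shapeenv_event: the recorded event
        # name defaults to the wrapped function's name but can be overridden.
        def decorator(fn: Callable) -> Callable:
            event_name = name if name is not None else fn.__name__

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                recorded_events.append(event_name)
                return fn(*args, **kwargs)

            return wrapper

        return decorator


    @record_event(name="evaluate_expr")
    def _inner_evaluate_expr(expr):
        return expr


    _inner_evaluate_expr("s0")
    print(recorded_events)  # ['evaluate_expr'], not ['_inner_evaluate_expr']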

def _inner_evaluate_expr(
self,
orig_expr: sympy.Basic,
hint: Optional[Union[int, bool, float]],
fx_node: Optional[torch.fx.Node],
size_oblivious: bool,
forcing_spec: bool,
_suppress_guards_tls: bool,
) -> sympy.Basic:
try:
return self._evaluate_expr(
@@ -6852,10 +6872,6 @@ def _evaluate_expr(
*,
forcing_spec: bool = False,
) -> sympy.Basic:
"""
Given an expression, evaluates it, adding guards if necessary
"""

# TODO: split conjunctions and evaluate them separately

if isinstance(