From a1c6178b2cbc2b560fadd83518284c7bfcee10d6 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 1 May 2025 20:05:13 -0700 Subject: [PATCH 1/2] Fix evaluate_expr to include suppress_guards_tls in cache key [ghstack-poisoned] --- torch/fx/experimental/recording.py | 18 +++++++++++++++--- torch/fx/experimental/symbolic_shapes.py | 22 +++++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index c8095357d06f..bb54eba11384 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -214,6 +214,10 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: # save_tracked_fakes: saves a snapshot of the TrackedFake list. # This is used when calling ShapeEnv.produce_guards at arbitrary points in time. # +# name: the name of the function being recorded. Normally (and by default) this +# is taken from the decorated function but can be set if you need to override +# it. +# # When to save the list of TrackedFake? # ===================================== # We should save the list of TrackedFake whenever the translation validation @@ -225,7 +229,9 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: # At the moment, there are 2 methods that save the list: # - ShapeEnv.evaluate_expr # - ShapeEnv.defer_runtime_assert -def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable: +def record_shapeenv_event( + *, save_tracked_fakes: bool = False, name: Optional[str] = None +) -> Callable: def decorator(fn: Callable) -> Callable: assert callable(fn) args = inspect.getfullargspec(fn).args @@ -233,7 +239,9 @@ def decorator(fn: Callable) -> Callable: "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " "code so that it calls into a method on ShapeEnv" ) - name = fn.__name__ + nonlocal name + if name is None: + name = fn.__name__ @functools.wraps(fn) def wrapper(*args, **kwargs): @@ -281,7 +289,11 @@ def retlog(r): ) # Record the event for 'fn'. event = ShapeEnvEvent( - fn, list(args), kwargs, tracked_fakes, name=fn.__name__ + fn, + list(args), + kwargs, + tracked_fakes, + name=name, ) # Play the event on this ShapeEnv. # NB: It's important to put the event first, because running diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 22673438e32b..fe1e6d3ab1fb 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -3966,7 +3966,8 @@ def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() node.meta[CURRENT_NODE_KEY] = get_current_node() - def _suppress_guards_tls(self) -> bool: + @staticmethod + def _suppress_guards_tls() -> bool: return getattr(TLS, "suppress_guards", False) @record_shapeenv_event() @@ -6841,8 +6842,6 @@ def evaluate_sym_node( sym_node.expr, sym_node.hint, sym_node.fx_node, size_oblivious ) - @lru_cache(256) - @record_shapeenv_event(save_tracked_fakes=True) def evaluate_expr( self, orig_expr: sympy.Basic, @@ -6851,6 +6850,23 @@ def evaluate_expr( size_oblivious: bool = False, *, forcing_spec: bool = False, + ) -> sympy.Basic: + # Add extra state that evaluate_expr() depends on. 
+ suppress_guards_tls = ShapeEnv._suppress_guards_tls() + return self._inner_evaluate_expr( + orig_expr, hint, fx_node, size_oblivious, forcing_spec, suppress_guards_tls + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr") + def _inner_evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]], + fx_node: Optional[torch.fx.Node], + size_oblivious: bool, + forcing_spec: bool, + _suppress_guards_tls: bool, ) -> sympy.Basic: try: return self._evaluate_expr( From e2aa64b399cdb7b63c8bccdb172cd6b0613549db Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Fri, 2 May 2025 13:26:01 -0700 Subject: [PATCH 2/2] Update on "Fix evaluate_expr to include suppress_guards_tls in cache key" ShapeEnv.evaluate_expr() behaves differently based on the (tls) global "suppress_guards" - so its cache key needs to include that value. This came up because #152662 triggered it in the test `test/dynamo/test_exc.py::ExcTests::test_trigger_bisect_on_error` - fixing this caused that test to work again. cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv [ghstack-poisoned] --- torch/fx/experimental/symbolic_shapes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index fe1e6d3ab1fb..3a65e65757c4 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -6868,6 +6868,10 @@ def _inner_evaluate_expr( forcing_spec: bool, _suppress_guards_tls: bool, ) -> sympy.Basic: + """ + Given an expression, evaluates it, adding guards if necessary + """ + try: return self._evaluate_expr( orig_expr, @@ -6898,10 +6902,6 @@ def _evaluate_expr( *, forcing_spec: bool = False, ) -> sympy.Basic: - """ - Given an expression, evaluates it, adding guards if necessary - """ - # TODO: split conjunctions and evaluate them separately if isinstance(
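
For context, here is a minimal, self-contained sketch of the caching pitfall these two patches address. The names (`Env`, `evaluate`, `TLS`) are illustrative only, not the real `ShapeEnv` code: a method memoized with `functools.lru_cache` keys only on its arguments, so any thread-local state it reads (like `suppress_guards`) is invisible to the cache. The fix is to read that state in an uncached entry point and pass it to a cached inner method, making it part of the cache key.

```python
import functools
import threading

TLS = threading.local()


class Env:
    @staticmethod
    def _suppress_guards_tls() -> bool:
        # Hidden thread-local state that evaluation depends on.
        return getattr(TLS, "suppress_guards", False)

    # Buggy pattern: the cache key is only (self, expr), so a result computed
    # while guards were suppressed can be replayed later when they are not.
    @functools.lru_cache(256)
    def evaluate_buggy(self, expr: str) -> str:
        suppressed = Env._suppress_guards_tls()
        return f"{expr} (guards suppressed={suppressed})"

    # Fixed pattern: the uncached entry point reads the thread-local flag and
    # forwards it to the cached inner method, so it becomes part of the key.
    def evaluate(self, expr: str) -> str:
        return self._inner_evaluate(expr, Env._suppress_guards_tls())

    @functools.lru_cache(256)
    def _inner_evaluate(self, expr: str, _suppress_guards: bool) -> str:
        return f"{expr} (guards suppressed={_suppress_guards})"


if __name__ == "__main__":
    env = Env()
    print(env.evaluate("s0 < 8"))        # guards suppressed=False
    TLS.suppress_guards = True
    print(env.evaluate("s0 < 8"))        # guards suppressed=True (new cache entry)

    print(env.evaluate_buggy("s0 < 8"))  # guards suppressed=True (first call, cached)
    TLS.suppress_guards = False
    print(env.evaluate_buggy("s0 < 8"))  # stale: still reports True from the cache
```

This mirrors the structure of the patch above, which splits `evaluate_expr` into a thin uncached wrapper plus `_inner_evaluate_expr(..., _suppress_guards_tls)` decorated with `lru_cache` and `record_shapeenv_event`.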