8000 [Set] Handle exception in ConstantVariable operation by guilhermeleobas · Pull Request #152987 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Set] Handle exception in ConstantVariable operation #152987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,18 @@ def test_tuple_contains(a, b):
return a + b
return a - b

@make_test
def test_set_invalid_ConstantVariable_op(a, b):
s = set({"banana", "apple", "orange"})
try:
s - 1
except TypeError:
return a + b
except Exception:
return a - b
else:
return a * b

@make_test
def test_set_update_bytecode(x):
# This produces bytecode SET_UPDATE since python 3.9
Expand Down
Empty file.
Empty file.
Empty file.
17 changes: 12 additions & 5 deletions torch/_dynamo/variables/base.py
10000
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .. import graph_break_hints, variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented_v2
from ..exc import raise_observed_exception, unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, Source
from ..utils import cmp_name_to_op_mapping, istype
Expand Down Expand Up @@ -515,11 +515,18 @@ def call_method(
hints=[],
)

return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name](
self.as_python_constant(), other.as_python_constant()
try:
return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name](
self.as_python_constant(), other.as_python_constant()
)
)
except Exception as e:
raise_observed_exception(
type(e),
tx,
args=[list(map(variables.ConstantVariable.create, e.args))],
)
)
hints = [
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
"Please report an issue to PyTorch.",
Expand Down
18 changes: 15 additions & 3 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
str_methods,
tensortype_to_dtype,
)
from .base import ValueMutationNew, VariableTracker
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
Expand Down Expand Up @@ -901,6 +901,12 @@ def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
*[x.as_python_constant() for x in args],
)
except Exception as exc:
raise_observed_exception(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamwen42, could you take a look at this change? Do you see any issues with replacing the unimplemented(...) call with raising an exception in Dynamo? If the exception isn't caught, it will propagate to CPython.

This change is needed to get some of the CPython set tests to pass.

type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
except AsPythonConstantNotImplementedError as exc:
unimplemented_v2(
gb_type="constant fold exception",
context=f"attempted to run function {fn} with arguments {args}",
Expand All @@ -922,14 +928,20 @@ def constant_fold_handler(tx: "InstructionTranslator", args, kwargs):
k: v.as_python_constant() for k, v in kwargs.items()
},
)
except Exception as exc:
except AsPythonConstantNotImplementedError as exc:
unimplemented_v2(
gb_type="constant fold exception",
context=f"attempted to run function {fn} with arguments {args} {kwargs}",
context=f"attempted to run function {fn} with arguments {args}",
explanation="Encountered exception when attempting to constant fold.",
hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc,
)
except Exception as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
return VariableTracker.build(tx, res)

handlers.append(constant_fold_handler)
Expand Down
7 changes: 6 additions & 1 deletion torch/_dynamo/variables/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ def call_method(
)
return SymNodeVariable.create(tx, proxy, add_target)
else:
return ConstantVariable.create(op(self.value, add_target))
try:
return ConstantVariable.create(op(self.value, add_target))
except Exception as e:
raise_observed_exception(
type(e), tx, args=list(map(ConstantVariable.create, e.args))
)
elif isinstance(self.value, bytes) and name == "decode":
method = getattr(self.value, name)
return ConstantVariable.create(method(*const_args, **const_kwargs))
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .. import graph_break_hints, variables
from ..exc import (
IncorrectUsage,
ObservedException,
UncapturedHigherOrderOpError,
unimplemented,
unimplemented_v2,
Expand All @@ -72,7 +73,7 @@ def deco(fn):
def graph_break_as_hard_error(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Unsupported as e:
except (Unsupported, ObservedException) as e:
msg = " Scroll up to find out what causes the graph break."
raise UncapturedHigherOrderOpError(reason + msg) from e

Expand Down
Loading
0