10000 [dynamo] replace `unimplemented` with `unimplemented_v2` in `variable… · pytorch/pytorch@429c3bc · GitHub
[go: up one dir, main page]

Skip to content

Commit 429c3bc

Browse files
committed
[dynamo] replace unimplemented with unimplemented_v2 in variables/torch_functions.py
This addresses part of #147913. ghstack-source-id: 26b24fa Pull Request resolved: #151278
1 parent ba4da63 commit 429c3bc

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

test/dynamo/test_subclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
769769
def fn(x):
770770
return x.ndim
771771

772-
msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
772+
msg = "`torch.compile` only support tracing certain types of overriden tensor subclass attributes"
773773
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
774774
x = torch.ones(2, 2).as_subclass(LocalSubclass)
775775
fn(x)

torch/_dynamo/variables/torch_function.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
)
4545
from torch.utils._device import DeviceContext
4646

47-
from ..exc import unimplemented
47+
from .. import graph_break_hints
48+
from ..exc import unimplemented_v2
4849
from ..guards import GuardBuilder, install_guard
4950
from ..polyfills import NoEnterTorchFunctionMode
5051
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
@@ -567,8 +568,13 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
567568
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
568569
return res
569570

570-
unimplemented(
571-
f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
571+
unimplemented_v2(
572+
gb_type="TypeError from user code",
573+
context=f"{fn=}, {args=}, {kwargs=}",
574+
explanation="All __torch_function__ overrides for returned NotImplemented",
575+
hints=[
576+
*graph_break_hints.USER_ERROR,
577+
],
572578
)
573579

574580

@@ -621,9 +627,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
621627
# base tensors, custom attribute accesses will graph break.
622628
import torch
623629

630+
# I think only `_base` is breaking because we aren't modelling view
631+
# relationship perfectly in some scenarios.
624632
if name in banned_attrs:
625-
unimplemented(
626-
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
633+
unimplemented_v2(
634+
gb_type="Unsupported tensor subclass attribute access",
635+
context=f"{name}",
636+
explanation="`torch.compile` currently can't trace this",
637+
hints=[
638+
f"Avoid accessing {name} of tensor subclass in torch.compile region",
639+
*graph_break_hints.SUPPORTABLE,
640+
],
627641
)
628642

629643
# Handle non-overriden attributes inherited from `torch.Tensor`.
@@ -676,8 +690,15 @@ def var_getattr(self, tx: "InstructionTranslator", name):
676690
)
677691

678692
elif attr_is_overriden:
679-
unimplemented(
680-
f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950
693+
unimplemented_v2(
694+
gb_type="Unsupported tensor subclass overriden attribute access",
695+
context=f"{name}",
696+
explanation="`torch.compile` only support tracing certain types of overriden tensor subclass attributes",
697+
hints=[
698+
f"Avoid accessing {name} of tensor subclass in torch.compile region",
699+
f"Renaming attribute `{name}` of type {self.class_type}",
700+
*graph_break_hints.SUPPORTABLE,
701+
],
681702
)
682703

683704
return super().var_getattr(tx, name)
@@ -709,9 +730,15 @@ def call_method(
709730
import torch
710731

711732
if _is_attr_overidden(tx, self, name):
712-
unimplemented(
713-
f"Calling overridden method {name} on a tensor"
714-
" subclass with a __torch_function__ override is not supported"
733+
unimplemented_v2(
734+
gb_type="Tensor subclass overriden method call",
735+
context=f"{name}",
736+
explanation="`torch.compile` currently can't trace this",
737+
hints=[
738+
f"Avoid calling {name} of tensor subclass in torch.compile region",
739+
f"Renaming method `{name}` of type {self.class_type}",
740+
*graph_break_hints.SUPPORTABLE,
741+
],
715742
)
716743

717744
# [Note: __torch_function__] Currently we only support methods that are defined on tensor

0 commit comments

Comments
 (0)
0