|
44 | 44 | )
|
45 | 45 | from torch.utils._device import DeviceContext
|
46 | 46 |
|
47 |
| -from ..exc import unimplemented |
| 47 | +from .. import graph_break_hints |
| 48 | +from ..exc import unimplemented_v2 |
48 | 49 | from ..guards import GuardBuilder, install_guard
|
49 | 50 | from ..polyfills import NoEnterTorchFunctionMode
|
50 | 51 | from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
@@ -567,8 +568,13 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
|
567 | 568 | if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
|
568 | 569 | return res
|
569 | 570 |
|
570 |
| - unimplemented( |
571 |
| - f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented" |
| 571 | + unimplemented_v2( |
| 572 | + gb_type="TypeError from user code", |
| 573 | + context=f"{fn=}, {args=}, {kwargs=}", |
| 574 | + explanation="All __torch_function__ overrides for returned NotImplemented", |
| 575 | + hints=[ |
| 576 | + *graph_break_hints.USER_ERROR, |
| 577 | + ], |
572 | 578 | )
|
573 | 579 |
|
574 | 580 |
|
@@ -621,9 +627,17 @@ def var_getattr(self, tx: "InstructionTranslator", name):
|
621 | 627 | # base tensors, custom attribute accesses will graph break.
|
622 | 628 | import torch
|
623 | 629 |
|
| 630 | + # I think only `_base` is breaking because we aren't modelling view |
| 631 | + # relationship perfectly in some scenarios. |
624 | 632 | if name in banned_attrs:
|
625 |
| - unimplemented( |
626 |
| - f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" |
| 633 | + unimplemented_v2( |
| 634 | + gb_type="Unsupported tensor subclass attribute access", |
| 635 | + context=f"{name}", |
| 636 | + explanation="`torch.compile` currently can't trace this", |
| 637 | + hints=[ |
| 638 | + f"Avoid accessing {name} of tensor subclass in torch.compile region", |
| 639 | + *graph_break_hints.SUPPORTABLE, |
| 640 | + ], |
627 | 641 | )
|
628 | 642 |
|
629 | 643 | # Handle non-overriden attributes inherited from `torch.Tensor`.
|
@@ -676,8 +690,15 @@ def var_getattr(self, tx: "InstructionTranslator", name):
|
676 | 690 | )
|
677 | 691 |
|
678 | 692 | elif attr_is_overriden:
|
679 |
| - unimplemented( |
680 |
| - f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950 |
| 693 | + unimplemented_v2( |
| 694 | + gb_type="Unsupported tensor subclass overriden attribute access", |
| 695 | + context=f"{name}", |
| 696 | + explanation="`torch.compile` only support tracing certain types of overriden tensor subclass attributes", |
| 697 | + hints=[ |
| 698 | + f"Avoid accessing {name} of tensor subclass in torch.compile region", |
| 699 | + f"Renaming attribute `{name}` of type {self.class_type}", |
| 700 | + *graph_break_hints.SUPPORTABLE, |
| 701 | + ], |
681 | 702 | )
|
682 | 703 |
|
683 | 704 | return super().var_getattr(tx, name)
|
@@ -709,9 +730,15 @@ def call_method(
|
709 | 730 | import torch
|
710 | 731 |
|
711 | 732 | if _is_attr_overidden(tx, self, name):
|
712 |
| - unimplemented( |
713 |
| - f"Calling overridden method {name} on a tensor" |
714 |
| - " subclass with a __torch_function__ override is not supported" |
| 733 | + unimplemented_v2( |
| 734 | + gb_type="Tensor subclass overriden method call", |
| 735 | + context=f"{name}", |
| 736 | + explanation="`torch.compile` currently can't trace this", |
| 737 | + hints=[ |
| 738 | + f"Avoid calling {name} of tensor subclass in torch.compile region", |
| 739 | + f"Renaming method `{name}` of type {self.class_type}", |
| 740 | + *graph_break_hints.SUPPORTABLE, |
| 741 | + ], |
715 | 742 | )
|
716 | 743 |
|
717 | 744 | # [Note: __torch_function__] Currently we only support methods that are defined on tensor
|
|
0 commit comments