8000 [Dynamo] Replace `unimplemented` with `unimplemented_v2` in `torch/_d… · pytorch/pytorch@0423a7b · GitHub
[go: up one dir, main page]

Skip to content

Commit 0423a7b

Browse files
shinkwilliamwen42
authored andcommitted
[Dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/nn_module.py (#151895)
Part of #147913 Replace `unimplemented` with`unimplemented_v2` in `torch/_dynamo/variables/nn_module.py` Pull Request resolved: #151895 Approved by: https://github.com/williamwen42 Co-authored-by: William Wen <william.wen42@gmail.com>
1 parent e2f9759 commit 0423a7b

File tree

1 file changed

+70
-8
lines changed

1 file changed

+70
-8
lines changed

torch/_dynamo/variables/nn_module.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from .. import graph_break_hints, trace_rules, variables
3636
from ..exc import (
3737
raise_observed_exception,
38-
unimplemented,
3938
unimplemented_v2,
4039
UnspecializeRestartAnalysis,
4140
Unsupported,
@@ -247,7 +246,17 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
247246
base = tx.output.get_submodule(self.module_key)
248247

249248
if object_has_getattribute(base):
250-
unimplemented("NNModuleVariable with custom __getattribute__")
249+
unimplemented_v2(
250+
gb_type="torch.nn.Module with a custom __getattribute__ defined",
251+
context=f"has_key_in_generic_dict {self} {key}",
252+
explanation="Dynamo does not support checking key existence "
253+
"on `nn.Module` instances that have a custom "
254+
"`__getattribute__` method defined.",
255+
hints=[
256+
"Avoid defining `__getattribute__` in your module.",
257+
*graph_break_hints.SUPPORTABLE,
258+
],
259+
)
251260

252261
if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
253262
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
@@ -259,14 +268,40 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
259268
def _custom_getattr_fallback(self, base, tx, name, obj_source):
260269
"""Check for a __getattr__ and handle it specially if it is implemented"""
261270
if object_has_getattribute(base):
262-
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
271+
unimplemented_v2(
272+
gb_type="torch.nn.Module with a custom __getattribute__ defined",
273+
context=f"var_getattr {self} {name}",
274+
explanation="Dynamo does not support checking key existence "
275+
"on `nn.Module` instances that have a custom "
276+
"`__getattribute__` method defined.",
277+
hints=[
278+
"Avoid defining `__getattribute__` in your module.",
279+
*graph_break_hints.SUPPORTABLE,
280+
],
281+
)
263282

264283
getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True)
265284
if getattr_fn is None:
266285
return None
267286

268287
if not isinstance(getattr_fn, types.FunctionType):
269-
unimplemented("torch.nn.Module with a non-function custom __getattr__")
288+
unimplemented_v2(
289+
gb_type="torch.nn.Module with a non-function custom __getattr__",
290+
context=f"var_getattr {self} {name}",
291+
explanation=(
292+
"Dynamo detected a nn.Module object with a custom "
293+
"`__getattr__` method, but this method is not a standard "
294+
"Python function (e.g., it might be implemented in C/C++). "
295+
"Dynamo cannot currently trace into such non-standard "
296+
"`__getattr__` methods."
297+
),
298+
hints=[
299+
"Avoid using objects with non-standard __getattr__ methods "
300+
"within the compiled region. If possible, implement "
301+
"__getattr__ as a standard Python function.",
302+
*graph_break_hints.SUPPORTABLE,
303+
],
304+
)
270305

271306
options = {"source": AttrSource(obj_source, "__getattr__")}
272307
return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
@@ -284,7 +319,14 @@ def var_getattr(self, tx: "InstructionTranslator", name):
284319
all_class_attribute_names.update(x.__dict__.keys())
285320

286321
if not self.source:
287-
unimplemented("GETATTR with no source")
322+
unimplemented_v2(
323+
gb_type="getattr with no source",
324+
context=f"var_getattr {self} {name}",
325+
explanation="Dynamo does not know how to access an attribute "
326+
"on an `nn.Module` instance that lacks a source. This is "
327+
"usually an internal error in Dynamo.",
328+
hints=[*graph_break_hints.DYNAMO_BUG],
329+
)
288330

289331
if name == "__dict__":
290332
return variables.GetAttrVariable(self, name, source=source)
@@ -565,7 +607,13 @@ def assert_all_args_kwargs_const():
565607
if not all(
566608
x.is_python_constant() for x in itertools.chain(args, kwargs.values())
567609
):
568-
unimplemented(f"non-const NNModule method {name}")
610+
unimplemented_v2(
611+
gb_type="non-const argument in nn.Module method",
612+
context=f"call_method: {self} {name} {args} {kwargs}",
613+
explanation="Dynamo does not support calling "
614+
f"method `{name}` of ``nn.Module`` {module} with non-constant arguments.",
615+
hints=[],
616+
)
569617

570618
def get_kwargs(*names):
571619
assert_all_args_kwargs_const()
@@ -756,7 +804,13 @@ def gen_source(source, name):
756804
elif args[0].is_python_constant():
757805
key = args[0].as_python_constant()
758806
else:
759-
unimplemented(f"getitem on NNModuleVariable with key {args[0]}")
807+
unimplemented_v2(
808+
gb_type="Unsupported key type for nn.Module.__getitem__",
809+
context=f"call_method: {self} {name} {args} {kwargs}",
810+
explanation="Dynamo does not support getitem on "
811+
"`nn.Module` with non-constant key.",
812+
hints=[],
813+
)
760814

761815
submod = module[key]
762816
return tx.output.register_attr_or_module(
@@ -991,7 +1045,15 @@ def call_method(
9911045
hasattr(method, "__code__")
9921046
and id(method.__code__) in self._nn_module_method_ids()
9931047
):
994-
unimplemented(f"UnspecializedNNModuleVariable missing {name}")
1048+
unimplemented_v2(
1049+
gb_type="UnspecializedNNModuleVariable missing method",
1050+
context=f"call_method: {self} {name} {args} {kwargs}",
1051+
explanation=f"Dynamo does not support tracing method {name} of nn.Module {self.value}",
1052+
hints=[
1053+
"Dynamo does not really define unspecialized nn.Module very well.",
1054+
*graph_break_hints.DIFFICULT,
1055+
],
1056+
)
9951057

9961058
# "_parameters" in self.value.__dict__ checks that module is initialized
9971059
if name == "__setattr__" and "_parameters" in self.value.__dict__:

0 commit comments

Comments
 (0)
0