8000 Update on "[dynamo][not ready] polyfill infra for classes" · pytorch/pytorch@7bd428d · GitHub
[go: up one dir, main page]

Skip to content

Commit 7bd428d

Browse files
committed
Update on "[dynamo][not ready] polyfill infra for classes"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
2 parents f770084 + b80a86b commit 7bd428d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

torch/_dynamo/side_effects.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def get_variable_cls(self, user_cls):
297297

298298
from .variables.ctx_manager import GenericContextWrappingVariable
299299
from .variables.torch_function import TorchFunctionModeVariable
300+
from .variables.user_defined import is_forbidden_context_manager
300301

301302
variable_cls: type[
302303
variables.UserDefinedObjectVariable
@@ -305,9 +306,13 @@ def get_variable_cls(self, user_cls):
305306
user_cls, TorchFunctionMode
306307
) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
307308
variable_cls = TorchFunctionModeVariable
308-
elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"):
309+
elif (
310+
hasattr(user_cls, "__enter__")
311+
and hasattr(user_cls, "__exit__")
312+
and not is_forbidden_context_manager(user_cls)
313+
):
309314
variable_cls = GenericContextWrappingVariable
310-
if issubclass(user_cls, torch.nn.Module):
315+
elif issubclass(user_cls, torch.nn.Module):
311316
variable_cls = variables.UnspecializedNNModuleVariable
312317
elif issubclass(user_cls, (dict, collections.OrderedDict)):
313318
variable_cls = variables.UserDefinedDictVariable

0 commit comments

Comments
 (0)
0