@@ -297,6 +297,7 @@ def get_variable_cls(self, user_cls):
297
297
298
298
from .variables .ctx_manager import GenericContextWrappingVariable
299
299
from .variables .torch_function import TorchFunctionModeVariable
300
+ from .variables .user_defined import is_forbidden_context_manager
300
301
301
302
variable_cls : type [
302
303
variables .UserDefinedObjectVariable
@@ -305,9 +306,13 @@ def get_variable_cls(self, user_cls):
305
306
user_cls , TorchFunctionMode
306
307
) and TorchFunctionModeVariable .is_supported_torch_function_mode (user_cls ):
307
308
variable_cls = TorchFunctionModeVariable
308
- elif hasattr (user_cls , "__enter__" ) and hasattr (user_cls , "__exit__" ):
309
+ elif (
310
+ hasattr (user_cls , "__enter__" )
311
+ and hasattr (user_cls , "__exit__" )
312
+ and not is_forbidden_context_manager (user_cls )
313
+ ):
309
314
variable_cls = GenericContextWrappingVariable
310
- if issubclass (user_cls , torch .nn .Module ):
315
+ elif issubclass (user_cls , torch .nn .Module ):
311
316
variable_cls = variables .UnspecializedNNModuleVariable
312
317
elif issubclass (user_cls , (dict , collections .OrderedDict )):
313
318
variable_cls = variables .UserDefinedDictVariable
0 commit comments