def substitute_class(original_class, supports_reconstruction=True):
    """
    Register a traceable polyfill class to be used in place of ``original_class``
    (usually a C++ class from a C extension) while Dynamo is tracing.

    Args:
        original_class: The class that should be substituted during tracing.
        supports_reconstruction: When true, the decorated class must also
            provide ``convert_to_original`` in addition to
            ``convert_to_traceable``.

    Returns:
        A class decorator. It stores the decorated class in
        ``torch._dynamo.trace_rules._polyfilled_class_mapping`` under
        ``original_class``, sets ``__global_name__`` on the decorated class and
        returns that same class object.

    Raises:
        AssertionError: If the decorated class lacks ``convert_to_traceable``,
            or lacks ``convert_to_original`` while ``supports_reconstruction``
            is true.

    .. note::

        The decorator only records the mapping; it does not wrap or replace
        ``original_class``. Code that uses ``original_class`` outside of
        tracing is therefore unaffected by the registration.
    """

    def inner(traceable_class):
        assert hasattr(
            traceable_class, "convert_to_traceable"
        ), "polyfill class must define convert_to_traceable"
        if supports_reconstruction:
            assert hasattr(
                traceable_class, "convert_to_original"
            ), "polyfill class must define convert_to_original to support reconstruction"

        # Name under which the class's converter is later installed as a global.
        traceable_class.__global_name__ = (
            f"___{traceable_class.__module__}_{traceable_class.__name__}___"
        )

        # Imported at decoration time rather than at module import time
        # (presumably to avoid an import cycle with trace_rules -- TODO confirm).
        from torch._dynamo.trace_rules import _polyfilled_class_mapping

        _polyfilled_class_mapping[original_class] = traceable_class
        return traceable_class

    return inner
@substitute_class(functools.partial, supports_reconstruction=True)
class partial:
    """Pure-Python counterpart of ``functools.partial``.

    An instance stores a callable together with leading positional arguments
    and keyword arguments, and applies them when the instance is called.
    """

    __slots__ = ("func", "args", "keywords", "__dict__", "__weakref__")

    def __new__(cls, func, /, *args, **keywords):
        if not callable(func):
            raise TypeError("the first argument must be callable")

        target = func
        bound_args = args
        bound_keywords = keywords
        if isinstance(target, partial):
            # Wrapping another polyfilled partial: fold its stored arguments
            # into this one so only a single level of indirection remains.
            bound_args = target.args + bound_args
            bound_keywords = {**target.keywords, **bound_keywords}
            target = target.func

        instance = super(partial, cls).__new__(cls)
        instance.func = target
        instance.args = bound_args
        instance.keywords = bound_keywords
        return instance

    def __call__(self, /, *args, **keywords):
        # Call-time keywords take precedence over the stored ones.
        merged_keywords = {**self.keywords, **keywords}
        return self.func(*self.args, *args, **merged_keywords)

    @staticmethod
    def convert_to_traceable(original_value):
        """Build a polyfilled ``partial`` from a ``functools.partial``."""
        assert isinstance(original_value, functools.partial)
        target = original_value.func
        bound_args = original_value.args
        bound_keywords = original_value.keywords
        return partial(target, *bound_args, **bound_keywords)

    @staticmethod
    def convert_to_original(value):
        """Build a ``functools.partial`` from a polyfilled ``partial``."""
        assert isinstance(value, partial)
        target = value.func
        bound_args = value.args
        bound_keywords = value.keywords
        return functools.partial(target, *bound_args, **bound_keywords)
-> dict[int, str]: unsupported_funcs = { diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index ba7a10267e2e1a..4d1534d0122de0 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -118,6 +118,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable from .user_defined import ( MutableMappingVariable, + PolyFilledUserDefinedClassVariable, RemovableHandleVariable, UserDefinedClassVariable, UserDefinedDictVariable, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f43bb435c08ca7..ec6ad56c9edb86 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -144,7 +144,6 @@ CollectionsNamedTupleFunction, CollectiveFunctionRewriteVariable, CreateTMADescriptorVariable, - FunctoolsPartialVariable, FunctoolsWrapsVariable, TritonKernelVariable, UserFunctionVariable, @@ -222,6 +221,7 @@ FrozenDataClassVariable, KeyedJaggedTensorVariable, MutableMappingVariable, + PolyFilledUserDefinedClassVariable, SourcelessGraphModuleVariable, UserDefinedClassVariable, UserDefinedDictVariable, @@ -670,32 +670,36 @@ def build_key_value(i, k, v): return build_checkpoint_variable(source=self.source) elif is_invoke_subgraph(value): return build_invoke_subgraph_variable(source=self.source) - elif isinstance(value, functools.partial): - func_src = AttrSource(self.get_source(), "func") - func_obj = VariableBuilder(self.tx, func_src)(value.func) - - args = [] - args_source = AttrSource(self.get_source(), "args") - for i, arg in enumerate(value.args): - args.append( - VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) - ) - - keywords = {} - keywords_source = AttrSource(self.get_source(), "keywords") - for k, v in value.keywords.items(): - if not ConstantVariable.is_literal(k): - unimplemented("functools.partial with non-literal keyword") - keywords[k] = VariableBuilder( - self.tx, DictGetItemSource(keywords_source, 
k) - )(v) - - install_guard( - self.get_source().make_guard(GuardBuilder.TYPE_MATCH), - keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), - args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), - ) - return FunctoolsPartialVariable(func_obj, args, keywords) + # elif isinstance(value, functools.partial): + # self.install_guards(GuardBuilder.TYPE_MATCH) + # new_value = polyfills.functools.partial(value.func, value.args, value.keywords) + # return UserDefinedObjectVariable(new_value, source=self.source) + # # func_src = AttrSource(self.get_source(), "func") + # # func_obj = VariableBuilder(self.tx, func_src)(value.func) + + # # args = [] + # # args_source = AttrSource(self.get_source(), "args") + # # for i, arg in enumerate(value.args): + # # args.append( + # # VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) + # # ) + + # # keywords = {} + # # keywords_source = AttrSource(self.get_source(), "keywords") + # # for k, v in value.keywords.items(): + # # if not ConstantVariable.is_literal(k): + # # unimplemented("functools.partial with non-literal keyword") + # # keywords[k] = VariableBuilder( + # # self.tx, DictGetItemSource(keywords_source, k) + # # )(v) + + # # install_guard( + # # self.get_source().make_guard(GuardBuilder.TYPE_MATCH), + # # keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), + # # args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + # # ) + # # breakpoint() + # # return FunctoolsPartialVariable(func_obj, args, keywords) elif is_typing(value): # typing.List, typing.Mapping, etc. self.install_guards(GuardBuilder.ID_MATCH) @@ -1097,6 +1101,15 @@ def build_key_value(i, k, v): # unlikely to change, so its ok to skip the guard here. 
return MethodWrapperVariable(value) elif issubclass(type(value), type): + if trace_class := trace_rules._polyfilled_class_mapping.get(value): + return PolyFilledUserDefinedClassVariable.create( + tx=self.tx, + orig_class=value, + orig_source=self.source, + trace_class=trace_class + ) + + if value in ( torch.utils.hooks.BackwardHook, torch.nn.Parameter, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 15f384eeeaad71..45705c45a77336 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1603,7 +1603,7 @@ def call_len(self, tx: "InstructionTranslator", *args, **kwargs): def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__getitem__", args[1:], kwargs) - def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type): + def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type_vt): try: arg_type = arg.python_type() except NotImplementedError: @@ -1611,7 +1611,7 @@ def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type): f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}" ) - isinstance_type = isinstance_type.as_python_constant() + isinstance_type = isinstance_type_vt.as_python_constant() if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: @@ -1652,6 +1652,9 @@ def check_type(ty): # handle __instancecheck__ defined in user class if ( isinstance(arg, variables.UserDefinedObjectVariable) + and not isinstance( + isinstance_type_vt, variables.PolyFilledUserDefinedClassVariable + ) and "__instancecheck__" in isinstance_type.__class__.__dict__ ): return variables.ConstantVariable.create( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0ea1583af572b6..e8f07a8962a8e6 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -13,7 +13,7 @@ import types import 
warnings import weakref -from typing import TYPE_CHECKING +from typing import Generic, TYPE_CHECKING, Callable, Any from typing_extensions import is_typeddict import torch._dynamo.config @@ -440,16 +440,24 @@ def call_function( elif self.value is weakref.ref: return variables.WeakRefVariable(args[0]) elif self.value is functools.partial: - if not args: - unimplemented("functools.partial malformed") - # The first arg, a callable (the ctor below will assert on types) - fn = args[0] - rest_args = args[1:] - # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the - # args and keywords - return variables.functions.FunctoolsPartialVariable( - fn, args=rest_args, keywords=kwargs + new_cls_vt = variables.UserDefinedClassVariable(polyfills.functools.partial) + var = tx.output.side_effects.track_object_new_from_user_defined_class( + new_cls_vt ) + var.call_method(tx, "__init__", args, kwargs) + return var + # new_value = functools.partial(identity) + # return UserDefinedObjectVariable(new_value, ) + # if not args: + # unimplemented("functools.partial malformed") + # # The first arg, a callable (the ctor below will assert on types) + # fn = args[0] + # rest_args = args[1:] + # # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the + # # args and keywords + # return variables.functions.FunctoolsPartialVariable( + # fn, args=rest_args, keywords=kwargs + # ) elif self.value is warnings.catch_warnings and not args: return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs) elif self.value is torch.cuda.device and not kwargs and len(args) == 1: @@ -724,12 +732,16 @@ def __init__( ) -> None: super().__init__(**kwargs) self.value = value + if is_polyfilled: + assert value_type is not None, "polyfill must provide the original type" self.value_type = value_type or type(value) - assert type(value) is self.value_type + if not is_polyfilled: + assert type(value) is 
class PolyFilledUserDefinedClassVariable(VariableTracker):
    """Tracks an original class that has a registered traceable polyfill.

    The variable reports the original class as its constant/proxy value, while
    attribute access, method calls and construction are delegated to a
    ``UserDefinedClassVariable`` wrapping the traceable class.
    """

    @staticmethod
    def create(tx, orig_class, orig_source, trace_class):
        """Build the variable for ``orig_class`` backed by ``trace_class``.

        The traceable class is given a source that loads it by name from its
        defining module; the returned variable keeps ``orig_source``.
        """
        trace_source = AttrSource(
            tx.import_source(trace_class.__module__), trace_class.__name__
        )
        trace_vt = UserDefinedClassVariable(trace_class, source=trace_source)

        return PolyFilledUserDefinedClassVariable(
            orig_class, trace_class, trace_vt, source=orig_source
        )

    def __init__(
        self, original_class, traceable_class, traceable_class_vt, **kwargs
    ) -> None:
        # Forward ``source`` and the other base-class kwargs; without this call
        # the ``source=orig_source`` passed by ``create`` is silently dropped.
        super().__init__(**kwargs)
        self.original_class = original_class
        self.traceable_class = traceable_class
        self.traceable_class_vt = traceable_class_vt

    def as_python_constant(self):
        return self.original_class

    def as_proxy(self):
        return self.original_class

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Construct the traceable object and wrap it as a polyfilled object."""
        obj = self.traceable_class_vt.call_function(tx, args, kwargs)
        assert isinstance(obj, UserDefinedObjectVariable)

        # Install the converter as a global so that reconstruction can call it
        # to turn the traceable instance back into the original class.
        global_name = self.traceable_class.__global_name__
        installed_global_name = tx.output.install_global_by_id(
            global_name, self.traceable_class.convert_to_original
        )
        return PolyFilledUserDefinedObjectVariable(
            obj,
            self.original_class,
            self.traceable_class,
            installed_global_name,
            mutation_type=obj.mutation_type,
        )

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return self.traceable_class_vt.var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.traceable_class_vt.call_method(tx, name, args, kwargs)


class PolyFilledUserDefinedObjectVariable(VariableTracker):
    """Wraps the variable tracking an instance of a traceable polyfill class.

    Methods that ``UserDefinedObjectVariable`` defines and this class does not
    are forwarded to ``udf_vt`` (installed by ``_populate`` below).
    """

    def __init__(
        self, udf_vt, original_class, traceable_class, installed_global_name, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.udf_vt = udf_vt
        self.original_class = original_class
        self.traceable_class = traceable_class
        self.installed_global_name = installed_global_name

    def reconstruct(self, codegen):
        """Emit a call of the installed converter on the traceable instance."""
        if self.udf_vt not in codegen.tempvars:
            unimplemented("Incorrect reconstruction for polyfilled object")

        # We have the tempvar for the instance of traceable class. For
        # reconstructing to the original class, call traceable_class
        # convert_to_original method.
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_global(self.installed_global_name, add=True),
                    codegen.create_load(codegen.tempvars[self.udf_vt]),
                ]
            )
        )
        codegen.extend_output(create_call_function(1, False))

    def python_type(self):
        # NB - This is intentional. For tracing purpose, we want to ensure that
        # the class is considered original class. If not, we will have wrong
        # conditionals on isinstance(value, class_type)
        return self.original_class


def _forward_to_udf_vt(
    name: str,
) -> Callable[[PolyFilledUserDefinedObjectVariable, Any, Any], Any]:
    """Return a method that calls ``name`` on the wrapped ``udf_vt``."""

    @functools.wraps(getattr(UserDefinedObjectVariable, name))
    def forward_to_udf_vt(
        self: PolyFilledUserDefinedObjectVariable, *args: Any, **kwargs: Any
    ) -> Any:
        return getattr(self.udf_vt, name)(*args, **kwargs)

    return forward_to_udf_vt


def _populate() -> None:
    """Install forwarding methods on ``PolyFilledUserDefinedObjectVariable``.

    Every callable defined directly on ``UserDefinedObjectVariable`` that the
    polyfilled wrapper does not define itself gets a forwarding method.
    """
    for name, value in UserDefinedObjectVariable.__dict__.items():
        if name not in PolyFilledUserDefinedObjectVariable.__dict__ and callable(value):
            setattr(
                PolyFilledUserDefinedObjectVariable, name, _forward_to_udf_vt(name)
            )


_populate()