8000 [dynamo][not ready] polyfill infra for classes · pytorch/pytorch@cb107f9 · GitHub
[go: up one dir, main page]

Skip to content

Commit cb107f9

Browse files
committed
[dynamo][not ready] polyfill infra for classes
ghstack-source-id: 17a0e0e Pull Request resolved: #146678
1 parent f366528 commit cb107f9

File tree

7 files changed

+247
-41
lines changed

7 files changed

+247
-41
lines changed

torch/_dynamo/decorators.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,33 @@ def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
409409
return wrapper
410410

411411

412+
def substitute_class(original_class, supports_reconstruction=True):
413+
"""
414+
Register a polyfill handler for a class, usually a C++ class from the C extension, to be
415+
used in place of the original class when inlining the original class in the graph.
416+
417+
.. note::
418+
419+
The polyfill handler is only used when inlining the original class. It is not used when
420+
the original class is called directly. In the eager mode, the decorated class calls
421+
the performant C++ class rather than the polyfill handler.
422+
"""
423+
424+
def inner(traceable_class):
425+
assert hasattr(traceable_class, "convert_to_traceable")
426+
if supports_reconstruction:
427+
assert hasattr(traceable_class, "convert_to_original")
428+
traceable_class.__global_name__ = f"___{traceable_class.__module__}_{traceable_class.__name__}___"
429+
430+
from torch._dynamo.trace_rules import _polyfilled_class_mapping
431+
_polyfilled_class_mapping[original_class] = traceable_class
432+
433+
_polyfilled_class_mapping
434+
return traceable_class
435+
436+
return inner
437+
438+
412439
# Helper function to flatten a tensor subclass and apply a function to
413440
# all inner tensors that match the outer dim. Used to reduce duplication
414441
# across the various marking APIs.

torch/_dynamo/polyfills/functools.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Iterable
77
from typing import Callable, TypeVar
88

9-
from ..decorators import substitute_in_graph
9+
from ..decorators import substitute_class, substitute_in_graph
1010

1111

1212
__all__ = ["reduce"]
@@ -45,3 +45,46 @@ def reduce(
4545
value = function(value, element)
4646

4747
return value
48+
49+
50+
@substitute_class(functools.partial, supports_reconstruction=True)
51+
class partial:
52+
"""New function with partial application of the given arguments
53+
and keywords.
54+
"""
55+
56+
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
57+
58+
def __new__(cls, func, /, *args, **keywords):
59+
if not callable(func):
60+
raise TypeError("the first argument must be callable")
61+
62+
if isinstance(func, partial):
63+
args = func.args + args
64+
keywords = {**func.keywords, **keywords}
65+
func = func.func
66+
67+
self = super(partial, cls).__new__(cls)
68+
69+
self.func = func
70+
self.args = args
71+
self.keywords = keywords
72+
return self
73+
74+
def __call__(self, /, *args, **keywords):
75+
keywords = {**self.keywords, **keywords}
76+
return self.func(*self.args, *args, **keywords)
77+
78+
@staticmethod
79+
def convert_to_traceable(original_value):
80+
assert isinstance(original_value, functools.partial)
81+
return partial(
82+
original_value.func, *original_value.args, **original_value.keywords
83+
)
84+
85+
@staticmethod
86+
def convert_to_original(value):
87+
assert isinstance(value, partial)
88+
return functools.partial(
89+
value.func, *value.args, **value.keywords
90+
)

torch/_dynamo/trace_rules.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,6 +3025,9 @@ def _polyfilled_function_ids() -> set[int]:
30253025
return set()
30263026

30273027

3028+
_polyfilled_class_mapping = {}
3029+
3030+
30283031
@FunctionIdSet
30293032
def _numpy_function_ids() -> dict[int, str]:
30303033
unsupported_funcs = {

torch/_dynamo/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
119119
from .user_defined import (
120120
MutableMappingVariable,
121+
PolyFilledUserDefinedClassVariable,
121122
RemovableHandleVariable,
122123
UserDefinedClassVariable,
123124
UserDefinedDictVariable,

torch/_dynamo/variables/builder.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@
144144
CollectionsNamedTupleFunction,
145145
CollectiveFunctionRewriteVariable,
146146
CreateTMADescriptorVariable,
147-
FunctoolsPartialVariable,
148147
FunctoolsWrapsVariable,
149148
TritonKernelVariable,
150149
UserFunctionVariable,
@@ -222,6 +221,7 @@
222221
FrozenDataClassVariable,
223222
KeyedJaggedTensorVariable,
224223
MutableMappingVariable,
224+
PolyFilledUserDefinedClassVariable,
225225
SourcelessGraphModuleVariable,
226226
UserDefinedClassVariable,
227227
UserDefinedDictVariable,
@@ -670,32 +670,36 @@ def build_key_value(i, k, v):
670670
return build_checkpoint_variable(source=self.source)
671671
elif is_invoke_subgraph(value):
672672
return build_invoke_subgraph_variable(source=self.source)
673-
elif isinstance(value, functools.partial):
674-
func_src = AttrSource(self.get_source(), "func")
675-
func_obj = VariableBuilder(self.tx, func_src)(value.func)
676-
677-
args = []
678-
args_source = AttrSource(self.get_source(), "args")
679-
for i, arg in enumerate(value.args):
680-
args.append(
681-
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
682-
)
683-
684-
keywords = {}
685-
keywords_source = AttrSource(self.get_source(), "keywords")
686-
for k, v in value.keywords.items():
687-
if not ConstantVariable.is_literal(k):
688-
unimplemented("functools.partial with non-literal keyword")
689-
keywords[k] = VariableBuilder(
690-
self.tx, DictGetItemSource(keywords_source, k)
691-
)(v)
692-
693-
install_guard(
694-
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
695-
keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
696-
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
697-
)
698-
return FunctoolsPartialVariable(func_obj, args, keywords)
673+
# elif isinstance(value, functools.partial):
674+
# self.install_guards(GuardBuilder.TYPE_MATCH)
675+
# new_value = polyfills.functools.partial(value.func, value.args, value.keywords)
676+
# return UserDefinedObjectVariable(new_value, source=self.source)
677+
# # func_src = AttrSource(self.get_source(), "func")
678+
# # func_obj = VariableBuilder(self.tx, func_src)(value.func)
679+
680+
# # args = []
681+
# # args_source = AttrSource(self.get_source(), "args")
682+
# # for i, arg in enumerate(value.args):
683+
# # args.append(
684+
# # VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
685+
# # )
686+
687+
# # keywords = {}
688+
# # keywords_source = AttrSource(self.get_source(), "keywords")
689+
# # for k, v in value.keywords.items():
690+
# # if not ConstantVariable.is_literal(k):
691+
# # unimplemented("functools.partial with non-literal keyword")
692+
# # keywords[k] = VariableBuilder(
693+
# # self.tx, DictGetItemSource(keywords_source, k)
694+
# # )(v)
695+
696+
# # install_guard(
697+
# # self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
698+
# # keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
699+
# # args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
700+
# # )
701+
# # breakpoint()
702+
# # return FunctoolsPartialVariable(func_obj, args, keywords)
699703
elif is_typing(value):
700704
# typing.List, typing.Mapping, etc.
701705
self.install_guards(GuardBuilder.ID_MATCH)
@@ -1097,6 +1101,15 @@ def build_key_value(i, k, v):
10971101
# unlikely to change, so its ok to skip the guard here.
10981102
return MethodWrapperVariable(value)
10991103
elif issubclass(type(value), type):
1104+
if trace_class := trace_rules._polyfilled_class_mapping.get(value):
1105+
return PolyFilledUserDefinedClassVariable.create(
1106+
tx=self.tx,
1107+
orig_class=value,
1108+
orig_source=self.source,
1109+
trace_class=trace_class
1110+
)
1111+
1112+
11001113
if value in (
11011114
torch.utils.hooks.BackwardHook,
11021115
torch.nn.Parameter,

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,15 +1603,15 @@ def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
16031603
def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
16041604
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
16051605

1606-
def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type):
1606+
def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type_vt):
16071607
try:
16081608
arg_type = arg.python_type()
16091609
except NotImplementedError:
16101610
unimplemented(
16111611
f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}"
16121612
)
16131613

1614-
isinstance_type = isinstance_type.as_python_constant()
1614+
isinstance_type = isinstance_type_vt.as_python_constant()
16151615

16161616
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
16171617

@@ -1652,6 +1652,9 @@ def check_type(ty):
16521652
# handle __instancecheck__ defined in user class
16531653
if (
16541654
isinstance(arg, variables.UserDefinedObjectVariable)
1655+
and not isinstance(
1656+
isinstance_type_vt, variables.PolyFilledUserDefinedClassVariable
1657+
)
16551658
and "__instancecheck__" in isinstance_type.__class__.__dict__
16561659
):
16571660
return variables.ConstantVariable.create(

torch/_dynamo/variables/user_defined.py

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import types
1414
import warnings
1515
import weakref
16-
from typing import TYPE_CHECKING
16+
from typing import Generic, TYPE_CHECKING, Callable, Any
1717
from typing_extensions import is_typeddict
1818

1919
import torch._dynamo.config
@@ -429,16 +429,24 @@ def call_function(
429429
elif self.value is weakref.ref:
430430
return variables.WeakRefVariable(args[0])
431431
elif self.value is functools.partial:
432-
if not args:
433-
unimplemented("functools.partial malformed")
434-
# The first arg, a callable (the ctor below will assert on types)
435-
fn = args[0]
436-
rest_args = args[1:]
437-
# guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
438-
# args and keywords
439-
return variables.functions.FunctoolsPartialVariable(
440-
fn, args=rest_args, keywords=kwargs
432+
new_cls_vt = variables.UserDefinedClassVariable(polyfills.functools.partial)
433+
var = tx.output.side_effects.track_object_new_from_user_defined_class(
434+
new_cls_vt
441435
)
436+
var.call_method(tx, "__init__", args, kwargs)
437+
return var
438+
# new_value = functools.partial(identity)
439+
# return UserDefinedObjectVariable(new_value, )
440+
# if not args:
441+
# unimplemented("functools.partial malformed")
442+
# # The first arg, a callable (the ctor below will assert on types)
443+
# fn = args[0]
444+
# rest_args = args[1:]
445+
# # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
446+
# # args and keywords
447+
# return variables.functions.FunctoolsPartialVariable(
448+
# fn, args=rest_args, keywords=kwargs
449+
# )
442450
elif self.value is warnings.catch_warnings and not args:
443451
return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
444452
elif self.value is torch.cuda.device and not kwargs and len(args) == 1:
@@ -728,12 +736,16 @@ def __init__(
728736
) -> None:
729737
super().__init__(**kwargs)
730738
self.value = value
739+
if is_polyfilled:
740+
assert value_type is not None, "polyfill must provide the original type"
731741
self.value_type = value_type or type(value)
732-
assert type(value) is self.value_type
742+
if not is_polyfilled:
743+
assert type(value) is self.value_type
733744
# This is used with __new__, when the new object is sourceless but the user class can be sourceful.
734745
self.cls_source = cls_source
735746
if cls_source is None and self.source is not None:
736747
self.cls_source = TypeSource(self.source)
748+
self.is_polyfilled = is_polyfilled
737749

738750
# These attributes are used to reconstruct the user defined object. The
739751
# pseudo code looks like this. Builtin C __new__ do not support kwargs,
@@ -1580,3 +1592,107 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
15801592

15811593
class RandomVariable(UserDefinedObjectVariable):
15821594
pass
1595+
1596+
1597+
class PolyFilledUserDefinedClassVariable(VariableTracker):
1598+
@staticmethod
1599+
def create(tx, orig_class, orig_source, trace_class):
1600+
trace_source = AttrSource(tx.import_source(trace_class.__module__), trace_class.__name__)
1601+
trace_vt = UserDefinedClassVariable(trace_class, source=trace_source)
1602+
1603+
return PolyFilledUserDefinedClassVariable(orig_class, trace_class, trace_vt, source=orig_source)
1604+
1605+
def __init__(self, original_class, traceable_class, traceable_class_vt, **kwargs) -> None:
1606+
self.original_class = original_class
1607+
self.traceable_class = traceable_class
1608+
self.traceable_class_vt = traceable_class_vt
1609+
# # NB - The `value` is changed to the polyfilled class. From here, the
1610+
# # polyfilled class is used to create the object.
1611+
# self.value = traceable_class
1612+
1613+
def as_python_constant(self):
1614+
return self.original_class
1615+
1616+
def as_proxy(self):
1617+
return self.original_class
1618+
1619+
def call_function(
1620+
self,
1621+
tx: "InstructionTranslator",
1622+
args: "list[VariableTracker]",
1623+
kwargs: "dict[str, VariableTracker]",
1624+
) -> "VariableTracker":
1625+
obj = self.traceable_class_vt.call_function(tx, args, kwargs)
1626+
assert isinstance(obj, UserDefinedObjectVariable)
1627+
# return obj
1628+
global_name = self.traceable_class.__global_name__
1629+
installed_global_name = tx.output.install_global_by_id(global_name, self.traceable_class.convert_to_original)
1630+
return PolyFilledUserDefinedObjectVariable(obj, self.original_class, self.traceable_class, installed_global_name, mutation_type=obj.mutation_type)
1631+
1632+
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
1633+
return self.traceable_class_vt.var_getattr(tx, name)
1634+
1635+
def call_method(
1636+
self,
1637+
tx,
1638+
name,
1639+
args: "list[VariableTracker]",
1640+
kwargs: "dict[str, VariableTracker]",
1641+
) -> "VariableTracker":
1642+
return self.traceable_class_vt.call_method(tx, name, args, kwargs)
1643+
1644+
1645+
class PolyFilledUserDefinedObjectVariable(VariableTracker):
1646+
def __init__(self, udf_vt, original_class, traceable_class, installed_global_name, **kwargs) -> None:
1647+
super().__init__(**kwargs)
1648+
self.udf_vt = udf_vt
1649+
self.original_class = original_class
1650+
self.traceable_class = traceable_class
1651+
self.installed_global_name = installed_global_name
1652+
1653+
def reconstruct(self, codegen):
1654+
if self.udf_vt not in codegen.tempvars:
1655+
unimplemented("Incorrect reconstruction for polyfilled object")
1656+
1657+
# We have the tempvar for the instance of traceable class. For
1658+
# reconstructing to the original class, call traceable_class
1659+
# convert_to_original method.
1660+
1661+
codegen.add_push_null(
1662+
lambda: codegen.extend_output(
1663+
[
1664+
codegen.create_load_global(self.installed_global_name, add=True),
1665+
codegen.create_load(codegen.tempvars[self.udf_vt]),
1666+
]
1667+
)
1668+
)
1669+
codegen.extend_output(create_call_function(1, False))
1670+
1671+
1672+
def python_type(self):
1673+
# NB - This is intentional. For tracing purpose, we want to ensure that
1674+
# the class is considered original class. If not, we will have wrong
1675+
# conditionals on isinstance(value, class_type)
1676+
return self.original_class
1677+
1678+
1679+
def _forward_to_udf_vt(
1680+
name: str,
1681+
) -> Callable[[PolyFilledUserDefinedObjectVariable, Any, Any], Any]:
1682+
@functools.wraps(getattr(UserDefinedObjectVariable, name))
1683+
def forward_to_udf_vt(
1684+
self: PolyFilledUserDefinedObjectVariable, *args: Any, **kwargs: Any
1685+
) -> Any:
1686+
return getattr(self.udf_vt, name)(*args, **kwargs)
1687+
1688+
return forward_to_udf_vt
1689+
1690+
1691+
def _populate() -> None:
1692+
for name, value in UserDefinedObjectVariable.__dict__.items():
1693+
if name not in PolyFilledUserDefinedObjectVariable.__dict__:
1694+
if callable(value):
1695+
setattr(PolyFilledUserDefinedObjectVariable, name, _forward_to_udf_vt(name))
1696+
1697+
1698+
_populate()

0 commit comments

Comments
 (0)
0