8000 [dynamo][not ready] polyfill infra for classes · pytorch/pytorch@535ed4a · GitHub
[go: up one dir, main page]

Skip to content

Commit 535ed4a

Browse files
committed
[dynamo][not ready] polyfill infra for classes
ghstack-source-id: 8957f6c Pull Request resolved: #146678
1 parent 0c6a000 commit 535ed4a

File tree

7 files changed

+250
-42
lines changed

7 files changed

+250
-42
lines changed

torch/_dynamo/decorators.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,33 @@ def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
409409
return wrapper
410410

411411

412+
def substitute_class(original_class, supports_reconstruction=True):
413+
"""
414+
Register a polyfill handler for a class, usually a C++ class from the C extension, to be
415+
used in place of the original class when inlining the original class in the graph.
416+
417+
.. note::
418+
419+
The polyfill handler is only used when inlining the original class. It is not used when
420+
the original class is called directly. In the eager mode, the decorated class calls
421+
the performant C++ class rather than the polyfill handler.
422+
"""
423+
424+
def inner(traceable_class):
425+
assert hasattr(traceable_class, "convert_to_traceable")
426+
if supports_reconstruction:
427+
assert hasattr(traceable_class, "convert_to_original")
428+
traceable_class.__global_name__ = f"___{traceable_class.__module__}_{traceable_class.__name__}___"
429+
430+
from torch._dynamo.trace_rules import _polyfilled_class_mapping
431+
_polyfilled_class_mapping[original_class] = traceable_class
432+
433+
_polyfilled_class_mapping
434+
return traceable_class
435+
436+
return inner
437+
438+
412439
# Helper function to flatten a tensor subclass and apply a function to
413440
# all inner tensors that match the outer dim. Used to reduce duplication
414441
# across the various marking APIs.

torch/_dynamo/polyfills/functools.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Iterable
77
from typing import Callable, TypeVar
88

9-
from ..decorators import substitute_in_graph
9+
from ..decorators import substitute_class, substitute_in_graph
1010

1111

1212
__all__ = ["reduce"]
@@ -45,3 +45,46 @@ def reduce(
4545
value = function(value, element)
4646

4747
return value
48+
49+
50+
@substitute_class(functools.partial, supports_reconstruction=True)
51+
class partial:
52+
"""New function with partial application of the given arguments
53+
and keywords.
54+
"""
55+
56+
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
57+
58+
def __new__(cls, func, /, *args, **keywords):
59+
if not callable(func):
60+
raise TypeError("the first argument must be callable")
61+
62+
if isinstance(func, partial):
63+
args = func.args + args
64+
keywords = {**func.keywords, **keywords}
65+
func = func.func
66+
67+
self = super(partial, cls).__new__(cls)
68+
69+
self.func = func
70+
self.args = args
71+
self.keywords = keywords
72+
return self
73+
74+
def __call__(self, /, *args, **keywords):
75+
keywords = {**self.keywords, **keywords}
76+
return self.func(*self.args, *args, **keywords)
77+
78+
@staticmethod
79+
def convert_to_traceable(original_value):
80+
assert isinstance(original_value, functools.partial)
81+
return partial(
82+
original_value.func, *original_value.args, **original_value.keywords
83+
)
84+
85+
@staticmethod
86+
def convert_to_original(value):
87+
assert isinstance(value, partial)
88+
return functools.partial(
89+
value.func, *value.args, **value.keywords
90+
)

torch/_dynamo/trace_rules.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,6 +3024,9 @@ def _polyfilled_function_ids() -> set[int]:
30243024
return set()
30253025

30263026

3027+
_polyfilled_class_mapping = {}
3028+
3029+
30273030
@FunctionIdSet
30283031
def _numpy_function_ids() -> dict[int, str]:
30293032
unsupported_funcs = {

torch/_dynamo/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
117117
from .user_defined import (
118118
MutableMappingVariable,
119+
PolyFilledUserDefinedClassVariable,
119120
RemovableHandleVariable,
120121
UserDefinedClassVariable,
121122
UserDefinedDictVariable,

torch/_dynamo/variables/builder.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@
144144
CollectionsNamedTupleFunction,
145145
CollectiveFunctionRewriteVariable,
146146
CreateTMADescriptorVariable,
147-
FunctoolsPartialVariable,
148147
FunctoolsWrapsVariable,
149148
TritonKernelVariable,
150149
UserFunctionVariable,
@@ -222,6 +221,7 @@
222221
FrozenDataClassVariable,
223222
KeyedJaggedTensorVariable,
224223
MutableMappingVariable,
224+
PolyFilledUserDefinedClassVariable,
225225
SourcelessGraphModuleVariable,
226226
UserDefinedClassVariable,
227227
UserDefinedDictVariable,
@@ -670,32 +670,36 @@ def build_key_value(i, k, v):
670670
return build_checkpoint_variable(source=self.source)
671671
elif is_invoke_subgraph(value):
672672
return build_invoke_subgraph_variable(source=self.source)
673-
elif isinstance(value, functools.partial):
674-
func_src = AttrSource(self.get_source(), "func")
675-
func_obj = VariableBuilder(self.tx, func_src)(value.func)
676-
677-
args = []
678-
args_source = AttrSource(self.get_source(), "args")
679-
for i, arg in enumerate(value.args):
680-
args.append(
681-
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
682-
)
683-
684-
keywords = {}
685-
keywords_source = AttrSource(self.get_source(), "keywords")
686-
for k, v in value.keywords.items():
687-
if not ConstantVariable.is_literal(k):
688-
unimplemented("functools.partial with non-literal keyword")
689-
keywords[k] = VariableBuilder(
690-
self.tx, DictGetItemSource(keywords_source, k)
691-
)(v)
692-
693-
install_guard(
694-
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
695-
keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
696-
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
697-
)
698-
return FunctoolsPartialVariable(func_obj, args, keywords)
673+
# elif isinstance(value, functools.partial):
674+
# self.install_guards(GuardBuilder.TYPE_MATCH)
675+
# new_value = polyfills.functools.partial(value.func, value.args, value.keywords)
676+
# return UserDefinedObjectVariable(new_value, source=self.source)
677+
# # func_src = AttrSource(self.get_source(), "func")
678+
# # func_obj = VariableBuilder(self.tx, func_src)(value.func)
679+
680+
# # args = []
681+
# # args_source = AttrSource(self.get_source(), "args")
682+
# # for i, arg in enumerate(value.args):
683+
# # args.append(
684+
# # VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
685+
# # )
686+
687+
# # keywords = {}
688+
# # keywords_source = AttrSource(self.get_source(), "keywords")
689+
# # for k, v in value.keywords.items():
690+
# # if not ConstantVariable.is_literal(k):
691+
# # unimplemented("functools.partial with non-literal keyword")
692+
# # keywords[k] = VariableBuilder(
693+
# # self.tx, DictGetItemSource(keywords_source, k)
694+
# # )(v)
695+
696+
# # install_guard(
697+
# # self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
698+
# # keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
699+
# # args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
700+
# # )
701+
# # breakpoint()
702+
# # return FunctoolsPartialVariable(func_obj, args, keywords)
699703
elif is_typing(value):
700704
# typing.List, typing.Mapping, etc.
701705
self.install_guards(GuardBuilder.ID_MATCH)
@@ -1097,6 +1101,15 @@ def build_key_value(i, k, v):
10971101
# unlikely to change, so its ok to skip the guard here.
10981102
return MethodWrapperVariable(value)
10991103
elif issubclass(type(value), type):
1104+
if trace_class := trace_rules._polyfilled_class_mapping.get(value):
1105+
return PolyFilledUserDefinedClassVariable.create(
1106+
tx=self.tx,
1107+
orig_class=value,
1108+
orig_source=self.source,
1109+
trace_class=trace_class
1110+
)
1111+
1112+
11001113
if value in (
11011114
torch.utils.hooks.BackwardHook,
11021115
torch.nn.Parameter,

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,15 +1591,15 @@ def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
15911591
def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
15921592
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
15931593

1594-
def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type):
1594+
def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type_vt):
15951595
try:
15961596
arg_type = arg.python_type()
15971597
except NotImplementedError:
15981598
unimplemented(
15991599
f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}"
16001600
)
16011601

1602-
isinstance_type = isinstance_type.as_python_constant()
1602+
isinstance_type = isinstance_type_vt.as_python_constant()
16031603

16041604
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
16051605

@@ -1640,6 +1640,9 @@ def check_type(ty):
16401640
# handle __instancecheck__ defined in user class
16411641
if (
16421642
isinstance(arg, variables.UserDefinedObjectVariable)
1643+
and not isinstance(
1644+
isinstance_type_vt, variables.PolyFilledUserDefinedClassVariable
1645+
)
16431646
and "__instancecheck__" in isinstance_type.__class__.__dict__
16441647
):
16451648
return variables.ConstantVariable.create(

0 commit comments

Comments
 (0)
0