-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Make fx.node.map_arg() and .map_aggregate() generic #146248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
ba4ee7f
afd81cd
7674025
f15da1f
ae10f6c
1230b89
c63c17d
9dc40b5
5918536
e4aa732
441c089
dd5dc46
b52a69d
1160968
f248065
bbad690
c62e68b
e421e13
d13e435
41a2dec
cfc5f8b
8a2b953
dab69d3
e1bf862
2d5de70
1e18ebd
d0d7327
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
import types | ||
import warnings | ||
from collections.abc import Mapping, Sequence | ||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union | ||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union | ||
|
||
import torch | ||
from torch._C import _NodeBase | ||
|
@@ -56,6 +56,7 @@ | |
BaseArgumentTypes, | ||
] | ||
] | ||
ArgumentT = TypeVar("ArgumentT", bound=Argument) | ||
|
||
_legal_ops = dict.fromkeys( | ||
[ | ||
|
@@ -890,35 +891,41 @@ | |
|
||
|
||
@compatibility(is_backward_compatible=True)
def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT:
    """
    Apply ``fn`` recursively to each ``Node`` appearing in ``a``.

    ``a`` may be a list, tuple, slice, or dict: the return value has the
    same type and structure as the input (hence the ``ArgumentT`` TypeVar
    threading the input type through to the return type).

    Args:
        a: The argument (possibly an aggregate) to map over.
        fn: Callable applied to every ``Node`` found; non-``Node`` leaves
            are passed through unchanged.

    Returns:
        A structure of the same shape/type as ``a`` with every ``Node``
        replaced by ``fn(node)``.
    """
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
    # Delegate structural recursion to map_aggregate; only transform Nodes.
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
|
||
|
||
@compatibility(is_backward_compatible=True)
def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT:
    """
    Apply ``fn`` recursively to each object appearing in ``a``.

    ``a`` may be a list, tuple, slice, or dict: the return value has the
    same type and structure as the input (hence the ``ArgumentT`` TypeVar
    threading the input type through to the return type).

    Args:
        a: The argument (possibly an aggregate) to map over.
        fn: Callable applied to every non-aggregate leaf value.

    Returns:
        A structure of the same shape/type as ``a`` with every leaf
        replaced by ``fn(leaf)``.
    """
    result: Argument

    if isinstance(a, tuple):
        # NOTE: list comprehensions (not generator expressions) are used
        # throughout this function so torch.compile/dynamo can trace it.
        t = [map_aggregate(elem, fn) for elem in a]
        # Support NamedTuple (detected via `_fields`) by repacking into the
        # original type; plain tuple subclasses take an iterable instead.
        result = type(a)(*t) if hasattr(a, "_fields") else type(a)(t)  # type: ignore[arg-type]
    elif isinstance(a, list):
        result = immutable_list([map_aggregate(elem, fn) for elem in a])
    elif isinstance(a, dict):
        result = immutable_dict({k: map_aggregate(v, fn) for k, v in a.items()})
    elif isinstance(a, slice):
        # Map each of the three slice components independently.
        result = slice(
            map_aggregate(a.start, fn),
            map_aggregate(a.stop, fn),
            map_aggregate(a.step, fn),
        )
    else:
        # Leaf value: apply fn directly.
        result = fn(a)

    # The recursion preserves the container type, so narrowing back to
    # ArgumentT is sound even though mypy cannot prove it.
    return cast(ArgumentT, result)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, comparing identity with a TypeVar actually works? What if t is a generic type, it shouldn't right (unless T is somehow an unresolved TypeVar?)?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`t` isn't a value: `t` is a type, extracted from the signature of a function that's being checked for backward compatibility in `_fn_to_stable_annotation_str`, a few pages above.

In the two cases where this branch is triggered, `t` actually contains `torch.fx.node.ArgumentT`, taken right from inside `inspect.signature().params`.

And yes, it does seem to work, inasmuch as it does fix that backward compatibility test there, which failed before. I had a much less elegant bandaid there before I cottoned on to this. 😁
Other test failures under study, stay tuned. Thanks as always for the critical eye!