Make fx.node.map_arg() and .map_aggregate() generic by rec · Pull Request #146248 · pytorch/pytorch · GitHub
Make fx.node.map_arg() and .map_aggregate() generic #146248

Closed
wants to merge 27 commits
6 changes: 5 additions & 1 deletion test/test_fx.py
@@ -29,7 +29,7 @@
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.node import Target, Argument, ArgumentT, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
@@ -4160,6 +4160,10 @@ def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
else:
return f'Callable{contained_type_str}'

if t is ArgumentT:
@Skylion007 (Collaborator) commented on Feb 3, 2025:

Wait, comparing identity with a TypeVar actually works? What if t is a generic type? It shouldn't match then, right (unless T is somehow an unresolved TypeVar)?

@rec (Collaborator, Author) replied on Feb 3, 2025:

t isn't a value: t is a type, extracted from the signature of a function that's being checked for backward compatibility in _fn_to_stable_annotation_str, a few pages above.

In the two cases where this branch is triggered, t actually contains torch.fx.node.ArgumentT, taken straight from inspect.signature().parameters.

And yes, it does seem to work, inasmuch as it does fix that backward compatibility test there, which failed before. I had a much less elegant bandaid there before I cottoned to this. 😁

Other test failures under study, stay tuned. Thanks as always for the critical eye!

# ArgumentT is a TypeVar bound to torch.fx.node.Argument
return f'torch.fx.node.Argument{contained_type_str}'

raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
f'Please add support for this type and confirm with the '
f'FX team that your signature change is valid.')
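As a quick sanity check of the identity comparison discussed above, the behaviour can be reproduced outside the PR in a few lines. This is a minimal sketch, assuming annotations are not stringized (no `from __future__ import annotations` in the inspected module); `T` and `f` are illustrative names, not from the PR:

```python
# Minimal sketch (not part of the PR): inspect.signature() hands back the
# annotation object itself, so an unresolved TypeVar compares by identity.
import inspect
from typing import TypeVar

T = TypeVar("T")

def f(x: T) -> T:
    return x

param = inspect.signature(f).parameters["x"]
assert param.annotation is T  # the very same TypeVar object, hence `t is ArgumentT` can match
```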
11 changes: 3 additions & 8 deletions torch/distributed/pipelining/stage.py
@@ -3,15 +3,15 @@
import logging
import operator
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, cast, Optional, Union

import torch
import torch.distributed as dist
import torch.fx as fx
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensor
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.fx.node import map_aggregate
from torch.fx.node import Argument, map_aggregate
from torch.nn.parallel import DistributedDataParallel
from torch.utils._pytree import tree_map_only

@@ -538,12 +538,7 @@ def get_recv_tensor(info):
else:
raise AssertionError(f"Expected _RecvInfo but got {type(info)}")

tensors = map_aggregate(
recv_infos, # type: ignore[arg-type]
get_recv_tensor,
)

return tensors
return map_aggregate(cast(Argument, recv_infos), get_recv_tensor)

def _retrieve_recv_activations(self, fwd_chunk_id: int):
"""
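For readers unfamiliar with the `cast` introduced above: `typing.cast()` is purely a static-typing construct, which is why it can stand in for the old `# type: ignore[arg-type]` comment. A minimal sketch with made-up names (`takes_int` and `value` are illustrative, not from the PR):

```python
# typing.cast() does nothing at runtime; it only tells the type checker to
# treat the value as the stated type. Names below are illustrative.
from typing import cast

def takes_int(x: int) -> int:
    return x + 1

value: object = 41
result = takes_int(cast(int, value))  # passes `value` through unchanged
assert result == 42
```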
37 changes: 22 additions & 15 deletions torch/fx/node.py
@@ -5,7 +5,7 @@
import types
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union

import torch
from torch._C import _NodeBase
@@ -56,6 +56,7 @@
BaseArgumentTypes,
]
]
ArgumentT = TypeVar("ArgumentT", bound=Argument)

_legal_ops = dict.fromkeys(
[
@@ -890,35 +891,41 @@


@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT:

[GitHub Actions / bc_linter notice, line 894 in torch/fx/node.py] Function map_arg: a changed from Argument to ArgumentT
"""
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
Apply fn recursively to each Node appearing in arg.

arg may be a list, tuple, slice, or dict: the return value will have
the same type and structure.
"""
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)


@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT:

[GitHub Actions / bc_linter notice, line 906 in torch/fx/node.py] Function map_aggregate: a changed from Argument to ArgumentT
"""
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
Apply fn recursively to each object appearing in arg.

arg may be a list, tuple, slice, or dict: the return value will have
the same type and structure.
"""
result: Argument

if isinstance(a, tuple):
t = tuple([map_aggregate(elem, fn) for elem in a])
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, "_fields") else type(a)(*t) # type: ignore[arg-type]
it = (map_aggregate(elem, fn) for elem in a)
result = type(a)(*it) if torch._dynamo.utils.is_namedtuple(a) else type(a)(it)
elif isinstance(a, list):
return immutable_list([map_aggregate(elem, fn) for elem in a])
result = immutable_list(map_aggregate(elem, fn) for elem in a)
Collaborator commented:

torch compile might need this to be a listcomp

@rec (Collaborator, Author) replied:

@Skylion007 Indeed, this was a problem, and so was the dict comprehension a few lines below. I shouldn't have monkeyed with it.

There is work being done right now with dynamo to make it better handle generators of all types, so perhaps one day I'll come back to these for a lark and modernize them once it works.
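For context on the thread above: the two spellings are equivalent in eager Python, so the concern is only about how torch.compile/dynamo traces them. A small illustrative sketch (`fn` and `items` are made-up stand-ins, not names from the PR):

```python
# Both forms build the same list in eager Python; the review concern is only
# about dynamo's tracing support for generator expressions.
def fn(x):
    return x + 1

items = [1, 2, 3]
from_generator = list(fn(x) for x in items)  # generator-expression form
from_listcomp = [fn(x) for x in items]       # list-comprehension form
assert from_generator == from_listcomp == [2, 3, 4]
```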

elif isinstance(a, dict):
rv = immutable_dict()
for k, v in a.items():
dict.__setitem__(rv, k, map_aggregate(v, fn))
return rv
result = immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
elif isinstance(a, slice):
return slice(
result = slice(
map_aggregate(a.start, fn),
map_aggregate(a.stop, fn),
map_aggregate(a.step, fn),
)
else:
return fn(a)
result = fn(a)

return cast(ArgumentT, result)
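Taken together, the change lets a type checker keep the precise container type across map_aggregate()/map_arg() instead of widening everything to the Argument union (ArgumentT is bound to Argument, so only argument-like inputs are accepted). A hedged usage sketch with made-up values, following the tuple/list/dict branches shown in the diff above:

```python
# Illustrative values only; behaviour follows the branches in the diff above.
from collections import namedtuple

from torch.fx.node import map_aggregate

Pair = namedtuple("Pair", ["left", "right"])

nested = (1, [2, 3], {"k": 4}, Pair(5, 6))
doubled = map_aggregate(nested, lambda x: x * 2 if isinstance(x, int) else x)

# Structure is preserved: the outer tuple stays a tuple, the namedtuple is
# repacked as Pair, and lists/dicts come back as fx immutable containers.
assert isinstance(doubled, tuple)
assert doubled[0] == 2
assert doubled[3] == Pair(10, 12)
```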