Make fx.node.map_arg() and .map_aggregate() generic · pytorch/pytorch@da0c725 · GitHub

Commit da0c725

Make fx.node.map_arg() and .map_aggregate() generic
ghstack-source-id: 7467ff5 Pull Request resolved: #146248
1 parent f95bdf5 commit da0c725
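
The commit narrows the signatures of torch.fx.node.map_arg() and map_aggregate() from Argument -> Argument to a TypeVar ArgumentT bound to Argument, so callers keep the concrete container type they pass in. A minimal sketch of what that means in practice (the traced function and the lambdas below are illustrative, not part of the commit):

import torch
from torch.fx import symbolic_trace
from torch.fx.node import map_arg, map_aggregate

def f(x):
    return x + 1

gm = symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# map_arg applies fn to every Node nested inside args; with the generic
# signature the result is typed as a tuple, matching add_node.args.
print(map_arg(add_node.args, lambda n: n.name))  # e.g. ('x', 1)

# map_aggregate applies fn to every leaf while preserving structure.
print(map_aggregate((1, 2, [3, 4]), lambda v: v * 2 if isinstance(v, int) else v))  # (2, 4, [6, 8])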

File tree

6 files changed: +73 -73 lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 10 additions & 10 deletions

@@ -6,43 +6,43 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,29630000000,0.015
+add_loop_inductor,compile_time_instruction_count,30000000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44000000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26500000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,977200000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19200000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17240000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10885050825,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11000000000,0.2
update_hint_regression,compile_time_instruction_count,1686000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1041000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1050000000,0.015
-symint_sum,compile_time_instruction_count,3324000000,0.015
+symint_sum,compile_time_instruction_count,3360000000,0.015

@@ -58,8 +58,8 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,9167000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3900000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10450000000,0.015

test/test_fx.py

Lines changed: 24 additions & 31 deletions

@@ -19,14 +19,25 @@
import typing
import unittest
import warnings
-from collections import namedtuple
-from copy import deepcopy
from math import sqrt
+from torch.multiprocessing import Process
+from torch.testing import FileCheck
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
+import torch.utils._pytree as pytree
+import torch.fx._pytree as fx_pytree
+from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
+from torch.fx.node import Target, Argument, ArgumentT, _format_arg
+from torch.fx.passes import shape_prop
+from torch.fx.immutable_collections import immutable_dict, immutable_list
+from torch.fx.experimental.rewriter import RewritingTracer
+from torch.fx.operator_schemas import get_signature_for_torch_op
+from copy import deepcopy
+from collections import namedtuple
from typing import Any, Callable, NamedTuple, Optional, Union

import torch
-import torch.fx._pytree as fx_pytree
-import torch.utils._pytree as pytree
+
from functorch.experimental import control_flow

from fx.named_tup import MyNamedTup

@@ -46,36 +57,10 @@
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401
from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
-from torch.fx import (
-    CodeGen,
-    Graph,
-    GraphModule,
-    Interpreter,
-    Node,
-    PH,
-    Proxy,
-    symbolic_trace,
-    Tracer,
-    Transformer,
-    wrap,
-)
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta
-from torch.fx.experimental.rewriter import RewritingTracer
-from torch.fx.immutable_collections import immutable_dict, immutable_list
-from torch.fx.node import _format_arg, Argument, Target
-from torch.fx.operator_schemas import get_signature_for_torch_op
-from torch.fx.passes import shape_prop

from torch.fx.proxy import TraceError
-from torch.multiprocessing import Process
-from torch.testing import FileCheck
-from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests,
-    onlyCPU,
-    ops,
-)
-from torch.testing._internal.common_methods_invocations import op_db

from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,

@@ -4414,7 +4399,15 @@ def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
        if len(contained) > 0 and contained[0] is not Ellipsis:
            return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
        else:
-            return f"Callable{contained_type_str}"
+            return f'Callable{contained_type_str}'
+
+        if t is ArgumentT:
+            # ArgumentT is a TypeVar bound to torch.fx.node.Argument
+            return f'torch.fx.node.Argument{contained_type_str}'
+
+        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
+                           f'Please add support for this type and confirm with the '
+                           f'FX team that your signature change is valid.')

        raise RuntimeError(
            f"Unrecognized type {t} used in BC-compatible type signature {sig_str}."

test/typing/pass/arithmetic_ops.py

Lines changed: 4 additions & 8 deletions

@@ -30,7 +30,7 @@
assert_type(TENSOR // TENSOR, Any)
assert_type(TENSOR / TENSOR, Tensor)
assert_type(TENSOR % TENSOR, Tensor)
-assert_type(TENSOR**TENSOR, Any)
+assert_type(TENSOR**TENSOR, Tensor)
assert_type(TENSOR << TENSOR, Tensor)
assert_type(TENSOR >> TENSOR, Tensor)
assert_type(TENSOR & TENSOR, Tensor)

@@ -49,7 +49,7 @@
assert_type(TENSOR // BOOL, Any)
assert_type(TENSOR / BOOL, Tensor)
assert_type(TENSOR % BOOL, Tensor)
-assert_type(TENSOR**BOOL, Any)
+assert_type(TENSOR**BOOL, Tensor)
assert_type(TENSOR << BOOL, Tensor)
assert_type(TENSOR >> BOOL, Tensor)
assert_type(TENSOR & BOOL, Tensor)

@@ -87,7 +87,7 @@
assert_type(TENSOR // INT, Any)
assert_type(TENSOR / INT, Tensor)
assert_type(TENSOR % INT, Tensor)
-assert_type(TENSOR**INT, Any)
+assert_type(TENSOR**INT, Tensor)
assert_type(TENSOR << INT, Tensor)
assert_type(TENSOR >> INT, Tensor)
assert_type(TENSOR & INT, Tensor)

@@ -125,7 +125,7 @@
assert_type(TENSOR // FLOAT, Any)
assert_type(TENSOR / FLOAT, Tensor)
assert_type(TENSOR % FLOAT, Tensor)
-assert_type(TENSOR**FLOAT, Any)
+assert_type(TENSOR**FLOAT, Tensor)
assert_type(TENSOR << FLOAT, Tensor)
assert_type(TENSOR >> FLOAT, Tensor)
assert_type(TENSOR & FLOAT, Tensor)

@@ -388,10 +388,6 @@ def __xor__(self, other: NUMBER) -> "Binary":  # type: ignore[override]
assert_type(BOOL**TENSOR, Any)
assert_type(FLOAT**TENSOR, Any)
assert_type(INT**TENSOR, Any)
-assert_type(TENSOR**BOOL, Any)
-assert_type(TENSOR**FLOAT, Any)
-assert_type(TENSOR**INT, Any)
-assert_type(TENSOR**TENSOR, Any)

assert_type(BOOL - TENSOR, Any)
assert_type(FLOAT - TENSOR, Any)
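
These typing-only tests use assert_type to pin down what a static checker infers; with Tensor.__pow__ now returning Tensor, the TENSOR ** x assertions move out of the known-to-be-Any block at the bottom of the file. The same check can be reproduced in a few lines (assumes a torch build containing this change and a checker such as mypy or pyright; assert_type is a no-op at runtime):

from typing_extensions import assert_type

import torch
from torch import Tensor

t = torch.ones(2)
assert_type(t ** 2, Tensor)    # previously inferred as Any
assert_type(t ** 0.5, Tensor)
assert_type(t ** t, Tensor)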

torch/_tensor.py

Lines changed: 9 additions & 3 deletions

@@ -6,7 +6,7 @@
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
-from typing import Any, Optional, Union
+from typing import Any, Callable, cast, Optional, Union

import torch
import torch._C as _C

@@ -1104,8 +1104,14 @@ def __rdiv__(self, other):
    __rtruediv__ = __rdiv__
    __itruediv__ = _C.TensorBase.__idiv__

-    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C.TensorBase.pow
+    __pow__ = cast(
+        Callable[
+            ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],
+            "Tensor",
+        ],
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
+            _C.TensorBase.pow
+        ),
    )
    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
        _C.TensorBase.pow_
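
The decorator _handle_torch_function_and_wrap_type_error_to_not_implemented erases the wrapped method's precise signature, so the diff pins __pow__ back to a concrete Callable with typing.cast. The same pattern in isolation (the decorator and function below are stand-ins, not PyTorch code):

from typing import Callable, cast

def swallow_type_errors(fn):
    # A wrapper written without ParamSpec: its static return type degrades to Any.
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except TypeError:
            return NotImplemented
    return wrapper

def _pow(base: float, exp: float) -> float:
    return base ** exp

# cast() re-attaches the signature the wrapper actually has at runtime,
# which is what the commit does for Tensor.__pow__.
pow_wrapped = cast(Callable[[float, float], float], swallow_type_errors(_pow))
print(pow_wrapped(2.0, 10.0))  # 1024.0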

torch/distributed/pipelining/stage.py

Lines changed: 3 additions & 8 deletions

@@ -3,15 +3,15 @@
import logging
import operator
from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, cast, Optional, Union

import torch
import torch.distributed as dist
import torch.fx as fx
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensor
from torch.distributed.fsdp import FSDPModule, fully_shard
-from torch.fx.node import map_aggregate
+from torch.fx.node import Argument, map_aggregate
from torch.nn.parallel import DistributedDataParallel
from torch.utils._pytree import tree_map_only

@@ -538,12 +538,7 @@ def get_recv_tensor(info):
            else:
                raise AssertionError(f"Expected _RecvInfo but got {type(info)}")

-        tensors = map_aggregate(
-            recv_infos,  # type: ignore[arg-type]
-            get_recv_tensor,
-        )
-
-        return tensors
+        return map_aggregate(cast(Argument, recv_infos), get_recv_tensor)

    def _retrieve_recv_activations(self, fwd_chunk_id: int):
        """

torch/fx/node.py

Lines changed: 23 additions & 13 deletions

@@ -5,7 +5,7 @@
import types
import warnings
from collections.abc import Mapping, Sequence
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union

import torch
from torch._C import _NodeBase

@@ -56,6 +56,7 @@
        BaseArgumentTypes,
    ]
]
+ArgumentT = TypeVar("ArgumentT", bound=Argument)

_legal_ops = dict.fromkeys(
    [

@@ -892,35 +893,44 @@ def __setattr__(self, name: str, value: Any) -> None:


@compatibility(is_backward_compatible=True)
-def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
+def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT:
    """
-    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
+    Apply fn recursively to each Node appearing in arg.
+
+    arg may be a list, tuple, slice, or dict with string keys: the return value will
+    have the same type and structure.
    """
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)


@compatibility(is_backward_compatible=True)
-def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
+def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT:
    """
-    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
+    Apply fn recursively to each object appearing in arg.
+
+    arg may be a list, tuple, slice, or dict with string keys: the return value will
+    have the same type and structure.
    """
+    result: Argument
+
    if isinstance(a, tuple):
-        t = tuple([map_aggregate(elem, fn) for elem in a])
+        it = (map_aggregate(elem, fn) for elem in a)
        # Support NamedTuple (if it has `_fields`) by repacking into original type.
-        return t if not hasattr(a, "_fields") else type(a)(*t)  # type: ignore[arg-type]
+        result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
    elif isinstance(a, list):
-        return immutable_list([map_aggregate(elem, fn) for elem in a])
+        result = immutable_list([map_aggregate(elem, fn) for elem in a])
    elif isinstance(a, dict):
-        rv = immutable_dict()
+        result = immutable_dict()
        for k, v in a.items():
-            dict.__setitem__(rv, k, map_aggregate(v, fn))
-        return rv
+            dict.__setitem__(result, k, map_aggregate(v, fn))  # type: ignore[index]
    elif isinstance(a, slice):
-        return slice(
+        result = slice(
            map_aggregate(a.start, fn),
            map_aggregate(a.stop, fn),
            map_aggregate(a.step, fn),
        )
    else:
-        return fn(a)
+        result = fn(a)
+
+    return cast(ArgumentT, result)
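
The refactor routes every branch through a single result variable and a final cast(ArgumentT, result), so the runtime behavior is unchanged: tuples (including NamedTuples) are repacked into their original type, lists and dicts come back as FX's immutable_list / immutable_dict, and slices are rebuilt field by field. A quick exercise of those branches (the expected outputs in the comments follow from the code above):

from collections import namedtuple

from torch.fx.node import map_aggregate

Point = namedtuple("Point", ["x", "y"])

def bump(v):
    return v + 1 if isinstance(v, int) else v

print(map_aggregate(Point(1, 2), bump))      # Point(x=2, y=3) -- repacked via type(a)(*it)
print(map_aggregate([1, (2, 3)], bump))      # [2, (3, 4)]     -- an immutable_list
print(map_aggregate({"a": 1}, bump))         # {'a': 2}        -- an immutable_dict
print(map_aggregate(slice(1, 10, 2), bump))  # slice(2, 11, 3)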
