PEP585: Add noqa to necessary tests (#146391) · pytorch/pytorch@1f8ff94 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1f8ff94

Browse files
aorenste authored and pytorchmergebot committed
PEP585: Add noqa to necessary tests (#146391)
Pull Request resolved: #146391 Approved by: https://github.com/justinchuby, https://github.com/Skylion007
1 parent b61032f commit 1f8ff94

File tree

7 files changed

+63
-31
lines changed

7 files changed

+63
-31
lines changed

test/onnx/test_onnxscript_runtime.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
44

5-
from typing import List
5+
from typing import Sequence
66

77
import onnx_test_common
88
import onnxscript
@@ -90,7 +90,11 @@ def forward(self, x, y, z):
9090

9191
@onnxscript.script(custom_opset)
9292
def layer_norm(
93-
X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
93+
X,
94+
axes: Sequence[int],
95+
weight: FLOAT[...],
96+
bias: FLOAT[...],
97+
eps: float,
9498
):
9599
mean = op.ReduceMean(X, axes=axes)
96100
D = X - mean # op.Sub(X, mean)

test/test_cpp_extensions_aot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def expected_return_type(func):
231231
Our Pybind functions have a signature of the form `() -> return_type`.
232232
"""
233233
# Imports needed for the `eval` below.
234-
from typing import List, Tuple # noqa: F401
234+
from typing import List, Tuple # noqa: F401, UP035
235235

236236
return eval(re.search("-> (.*)\n", func.__doc__).group(1))
237237

test/test_fx.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections import namedtuple
2323
from copy import deepcopy
2424
from math import sqrt
25-
from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
25+
from typing import Any, Callable, NamedTuple, Optional, Union
2626

2727
import torch
2828
import torch.fx._pytree as fx_pytree
@@ -2270,10 +2270,19 @@ def test_typename_print(self):
22702270
graph: torch.fx.Graph = torch.fx.Graph()
22712271
x: torch.fx.Node = graph.create_node("placeholder", "x")
22722272
b: torch.fx.Node = graph.create_node(
2273-
"call_function", target=torch.relu, args=(x,), type_expr=List[float]
2273+
"call_function", target=torch.relu, args=(x,), type_expr=list[float]
22742274
)
22752275
output: torch.fx.Node = graph.output(b)
22762276

2277+
self.assertTrue('list[float]' in str(graph))
2278+
2279+
def test_typename_print_pre_pep585(self):
2280+
graph : torch.fx.Graph = torch.fx.Graph()
2281+
x : torch.fx.Node = graph.create_node('placeholder', 'x')
2282+
b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
2283+
type_expr=typing.List[float]) # noqa: UP006
2284+
output : torch.fx.Node = graph.output(b)
2285+
22772286
self.assertTrue("typing.List[float]" in str(graph))
22782287

22792288
def test_layout(self):
@@ -2922,6 +2931,19 @@ def other(self, x: list[str]) -> list[str]:
29222931
def forward(self, x: list[str]) -> list[str]:
29232932
return self.other(x)
29242933

2934+
traced = symbolic_trace(ReturnTypeModule())
2935+
self.assertIn("-> list[str]", traced._code)
2936+
scripted = torch.jit.script(traced)
2937+
self.assertIn("-> List[str]", scripted.code)
2938+
2939+
def test_return_type_exists_pre_pep585(self):
2940+
class ReturnTypeModule(torch.nn.Module):
2941+
def other(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
2942+
return x
2943+
2944+
def forward(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
2945+
return self.other(x)
2946+
29252947
traced = symbolic_trace(ReturnTypeModule())
29262948
self.assertIn("-> typing_List[str]", traced._code)
29272949
scripted = torch.jit.script(traced)
@@ -3735,7 +3757,7 @@ def test_annotation_with_future(self):
37353757
@unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
37363758
def test_annotations_empty_tuple(self):
37373759
class Foo(torch.nn.Module):
3738-
def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
3760+
def forward(self, x: typing.Tuple[()], y: typing.Tuple[str, typing.Tuple[()]]): # noqa: UP006
37393761
return "foo"
37403762

37413763
traced = torch.fx.symbolic_trace(Foo())
@@ -4320,10 +4342,10 @@ def default_val_str(val):
43204342
tuple,
43214343
type,
43224344
typing.Callable,
4323-
typing.Dict,
4324-
typing.List,
4325-
typing.Tuple,
4326-
typing.Type,
4345+
typing.Dict, # noqa: UP006
4346+
typing.List, # noqa: UP006
4347+
typing.Tuple, # noqa: UP006
4348+
typing.Type, # noqa: UP006
43274349
typing.Union,
43284350
}
43294351

test/test_fx_experimental.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import typing
1313
import unittest
1414
from types import BuiltinFunctionType
15-
from typing import Callable, List, NamedTuple, Optional, Union
15+
from typing import Callable, NamedTuple, Optional, Union
1616

1717
import torch
1818
import torch.fx.experimental.meta_tracer
@@ -1548,25 +1548,25 @@ def test_type_matches(self):
15481548
(Optional[list[int]], list[int]),
15491549
] + [
15501550
# pre-PEP585 signatures
1551-
(typing.List[int], int),
1552-
(typing.List[int], create_type_hint([int, int])),
1553-
(typing.List[int], create_type_hint((int, int))),
1554-
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
1551+
(typing.List[int], int), # noqa: UP006
1552+
(typing.List[int], create_type_hint([int, int])), # noqa: UP006
1553+
(typing.List[int], create_type_hint((int, int))), # noqa: UP006
1554+
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), # noqa: UP006
15551555
(
1556-
typing.List[torch.Tensor],
1556+
typing.List[torch.Tensor], # noqa: UP006
15571557
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
15581558
),
1559-
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
1560-
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
1561-
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
1559+
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), # noqa: UP006
1560+
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), # noqa: UP006
1561+
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), # noqa: UP006
15621562
(
1563-
typing.List[torch.Tensor],
1563+
typing.List[torch.Tensor], # noqa: UP006
15641564
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
15651565
),
1566-
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
1567-
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
1568-
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
1569-
(Optional[typing.List[int]], typing.List[int]),
1566+
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), # noqa: UP006
1567+
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), # noqa: UP006
1568+
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]), # noqa: UP006
1569+
(Optional[typing.List[int]], typing.List[int]), # noqa: UP006
15701570
]
15711571

15721572
for sig_type, arg_type in should_be_equal:
@@ -1575,7 +1575,7 @@ def test_type_matches(self):
15751575
should_fail = [
15761576
(int, float),
15771577
(Union[int, float], str),
1578-
(list[torch.Tensor], List[int]),
1578+
(list[torch.Tensor], typing.List[int]), # noqa: UP006
15791579
] + [
15801580
# pre-PEP585 signatures
15811581
(list[torch.Tensor], list[int]),

torch/autograd/grad_mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# mypy: allow-untyped-defs
2-
from typing import Any, Tuple, Union
2+
from typing import Any, Union
33

44
import torch
55
from torch.utils._contextlib import (
@@ -386,7 +386,7 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
386386
387387
"""
388388

389-
def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
389+
def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
390390
self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
391391
assert isinstance(self.tensors, tuple)
392392
self.prev_versions = tuple(t._version for t in self.tensors)

torch/fx/graph.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,13 @@ def type_repr(o: Any):
455455

456456
typename = _type_repr(o)
457457

458-
if hasattr(o, "__origin__"):
459-
# This is a generic type, e.g. typing.List[torch.Tensor]
460-
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
458+
if origin_type := getattr(o, "__origin__", None):
459+
# list[...], typing.List[...], TensorType[...]
460+
461+
if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined]
462+
# This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor]
463+
origin_type = _origin_type_map.get(origin_type, origin_type)
464+
461465
origin_typename = add_global(_type_repr(origin_type), origin_type)
462466

463467
if hasattr(o, "__args__"):

torch/fx/node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def _type_repr(obj: object) -> str:
126126
typically enough to uniquely identify a type. For everything
127127
else, we fall back on repr(obj).
128128
"""
129-
if isinstance(obj, type):
129+
# Extension: If we don't ignore GenericAlias then `list[int]` will print
130+
# simply "list".
131+
if isinstance(obj, type) and not isinstance(obj, types.GenericAlias):
130132
if obj.__module__ == "builtins":
131133
return obj.__qualname__
132134
return f"{obj.__module__}.{obj.__qualname__}"

0 commit comments

Comments (0)