Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in `.pyi` stub files (#129419) · pytorch/pytorch@e40f50c · GitHub
[go: up one dir, main page]

Skip to content

Commit e40f50c

Browse files
XuehaiPan authored and pytorchmergebot committed
Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419)
- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.

Note that in `.pyi` stub files, we do not need `from __future__ import annotations`, so this PR does not violate issue #117449.

Pull Request resolved: #129419
Approved by: https://github.com/ezyang
ghstack dependencies: #129375, #129376
1 parent 494057d commit e40f50c

21 files changed

+414
-406
lines changed

torch/_C/_aoti.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from ctypes import c_void_p
2-
from typing import List
32

43
from torch import Tensor
54

65
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
76

87
# Tensor to AtenTensorHandle
9-
def unsafe_alloc_void_ptrs_from_tensors(tensors: List[Tensor]) -> List[c_void_p]: ...
8+
def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
109
def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
1110

1211
# AtenTensorHandle to Tensor
1312
def alloc_tensors_by_stealing_from_void_ptrs(
14-
handles: List[c_void_p],
15-
) -> List[Tensor]: ...
13+
handles: list[c_void_p],
14+
) -> list[Tensor]: ...
1615
def alloc_tensor_by_stealing_from_void_ptr(
1716
handle: c_void_p,
1817
) -> Tensor: ...

torch/_C/_autograd.pyi

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# mypy: allow-untyped-defs
22
from enum import Enum
3-
from typing import Any, Callable, List, Optional, Set
3+
from typing import Any, Callable
44

55
import torch
6-
7-
from ._profiler import (
6+
from torch._C._profiler import (
87
_ProfilerEvent,
98
ActiveProfilerType,
109
ProfilerActivity,
@@ -47,7 +46,7 @@ class ProfilerEvent:
4746
def name(self) -> str: ...
4847
def node_id(self) -> int: ...
4948
def sequence_nr(self) -> int: ...
50-
def shapes(self) -> List[List[int]]: ...
49+
def shapes(self) -> list[list[int]]: ...
5150
def thread_id(self) -> int: ...
5251
def flops(self) -> float: ...
5352
def is_async(self) -> bool: ...
@@ -61,37 +60,37 @@ class _KinetoEvent:
6160
def duration_ns(self) -> int: ...
6261
def is_async(self) -> bool: ...
6362
def linked_correlation_id(self) -> int: ...
64-
def shapes(self) -> List[List[int]]: ...
65-
def dtypes(self) -> List[str]: ...
66-
def concrete_inputs(self) -> List[Any]: ...
63+
def shapes(self) -> list[list[int]]: ...
64+
def dtypes(self) -> list[str]: ...
65+
def concrete_inputs(self) -> list[Any]: ...
6766
def device_type(self) -> DeviceType: ...
6867
def start_thread_id(self) -> int: ...
6968
def end_thread_id(self) -> int: ...
7069
def correlation_id(self) -> int: ...
7170
def fwd_thread_id(self) -> int: ...
72-
def stack(self) -> List[str]: ...
71+
def stack(self) -> list[str]: ...
7372
def scope(self) -> int: ...
7473
def sequence_nr(self) -> int: ...
7574
def flops(self) -> int: ...
7675
def cuda_elapsed_us(self) -> int: ...
7776
def privateuse1_elapsed_us(self) -> int: ...
7877

7978
class _ProfilerResult:
80-
def events(self) -> List[_KinetoEvent]: ...
81-
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
79+
def events(self) -> list[_KinetoEvent]: ...
80+
def legacy_events(self) -> list[list[ProfilerEvent]]: ...
8281
def save(self, path: str) -> None: ...
83-
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
82+
def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
8483
def trace_start_ns(self) -> int: ...
8584

8685
class SavedTensor: ...
8786

8887
def _enable_profiler(
8988
config: ProfilerConfig,
90-
activities: Set[ProfilerActivity],
89+
activities: set[ProfilerActivity],
9190
) -> None: ...
9291
def _prepare_profiler(
9392
config: ProfilerConfig,
94-
activities: Set[ProfilerActivity],
93+
activities: set[ProfilerActivity],
9594
) -> None: ...
9695
def _disable_profiler() -> _ProfilerResult: ...
9796
def _profiler_enabled() -> bool: ...
@@ -101,7 +100,7 @@ def _get_sequence_nr() -> int: ...
101100
def kineto_available() -> bool: ...
102101
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
103102
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
104-
def _supported_activities() -> Set[ProfilerActivity]: ...
103+
def _supported_activities() -> set[ProfilerActivity]: ...
105104
def _enable_record_function(enable: bool) -> None: ...
106105
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
107106
def _push_saved_tensors_default_hooks(
@@ -111,11 +110,11 @@ def _push_saved_tensors_default_hooks(
111110
def _pop_saved_tensors_default_hooks() -> None: ...
112111
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
113112
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
114-
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
113+
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
115114
def _profiler_type() -> ActiveProfilerType: ...
116115
def _saved_tensors_hooks_enable() -> None: ...
117116
def _saved_tensors_hooks_disable(message: str) -> None: ...
118-
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
117+
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
119118
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
120119

121120
class CreationMeta(Enum):

torch/_C/_distributed_autograd.pyi

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# mypy: allow-untyped-defs
2-
from typing import Any, Dict, List, Set
2+
from typing import Any
33

44
import torch
55

66
# This module is defined in torch/csrc/distributed/autograd/init.cpp
77

88
class DistAutogradContext:
99
def _context_id(self) -> int: ...
10-
def _recv_functions(self) -> Dict[int, Any]: ...
11-
def _send_functions(self) -> Dict[int, Any]: ...
12-
def _known_worker_ids(self) -> Set[int]: ...
10+
def _recv_functions(self) -> dict[int, Any]: ...
11+
def _send_functions(self) -> dict[int, Any]: ...
12+
def _known_worker_ids(self) -> set[int]: ...
1313

1414
def _new_context() -> DistAutogradContext: ...
1515
def _release_context(context_id: int) -> None: ...
@@ -18,10 +18,10 @@ def _is_valid_context(worker_id: int) -> bool: ...
1818
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
1919
def _current_context() -> DistAutogradContext: ...
2020
def _init(worker_id: int) -> None: ...
21-
def _get_debug_info() -> Dict[str, str]: ...
21+
def _get_debug_info() -> dict[str, str]: ...
2222
def backward(
2323
context_id: int,
24-
roots: List[torch.Tensor],
24+
roots: list[torch.Tensor],
2525
retain_graph=False,
2626
) -> None: ...
27-
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
27+
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...

0 commit comments

Comments (0)