1
1
# mypy: allow-untyped-defs
2
2
from enum import Enum
3
- from typing import Any , Callable , List , Optional , Set
3
+ from typing import Any , Callable
4
4
5
5
import torch
6
-
7
- from ._profiler import (
6
+ from torch ._C ._profiler import (
8
7
_ProfilerEvent ,
9
8
ActiveProfilerType ,
10
9
ProfilerActivity ,
@@ -47,7 +46,7 @@ class ProfilerEvent:
47
46
def name (self ) -> str : ...
48
47
def node_id (self ) -> int : ...
49
48
def sequence_nr (self ) -> int : ...
50
- def shapes (self ) -> List [ List [int ]]: ...
49
+ def shapes (self ) -> list [ list [int ]]: ...
51
50
def thread_id (self ) -> int : ...
52
51
def flops (self ) -> float : ...
53
52
def is_async (self ) -> bool : ...
@@ -61,37 +60,37 @@ class _KinetoEvent:
61
60
def duration_ns (self ) -> int : ...
62
61
def is_async (self ) -> bool : ...
63
62
def linked_correlation_id (self ) -> int : ...
64
- def shapes (self ) -> List [ List [int ]]: ...
65
- def dtypes (self ) -> List [str ]: ...
66
- def concrete_inputs (self ) -> List [Any ]: ...
63
+ def shapes (self ) -> list [ list [int ]]: ...
64
+ def dtypes (self ) -> list [str ]: ...
65
+ def concrete_inputs (self ) -> list [Any ]: ...
67
66
def device_type (self ) -> DeviceType : ...
68
67
def start_thread_id (self ) -> int : ...
69
68
def end_thread_id (self ) -> int : ...
70
69
def correlation_id (self ) -> int : ...
71
70
def fwd_thread_id (self ) -> int : ...
72
- def stack (self ) -> List [str ]: ...
71
+ def stack (self ) -> list [str ]: ...
73
72
def scope (self ) -> int : ...
74
73
def sequence_nr (self ) -> int : ...
75
74
def flops (self ) -> int : ...
76
75
def cuda_elapsed_us (self ) -> int : ...
77
76
def privateuse1_elapsed_us (self ) -> int : ...
78
77
79
78
class _ProfilerResult :
80
- def events (self ) -> List [_KinetoEvent ]: ...
81
- def legacy_events (self ) -> List [ List [ProfilerEvent ]]: ...
79
+ def events (self ) -> list [_KinetoEvent ]: ...
80
+ def legacy_events (self ) -> list [ list [ProfilerEvent ]]: ...
82
81
def save (self , path : str ) -> None : ...
83
- def experimental_event_tree (self ) -> List [_ProfilerEvent ]: ...
82
+ def experimental_event_tree (self ) -> list [_ProfilerEvent ]: ...
84
83
def trace_start_ns (self ) -> int : ...
85
84
86
85
class SavedTensor : ...
87
86
88
87
def _enable_profiler (
89
88
config : ProfilerConfig ,
90
- activities : Set [ProfilerActivity ],
89
+ activities : set [ProfilerActivity ],
91
90
) -> None : ...
92
91
def _prepare_profiler (
93
92
config : ProfilerConfig ,
94
- activities : Set [ProfilerActivity ],
93
+ activities : set [ProfilerActivity ],
95
94
) -> None : ...
96
95
def _disable_profiler () -> _ProfilerResult : ...
97
96
def _profiler_enabled () -> bool : ...
@@ -101,7 +100,7 @@ def _get_sequence_nr() -> int: ...
101
100
def kineto_available () -> bool : ...
102
101
def _record_function_with_args_enter (name : str , * args ) -> torch .Tensor : ...
103
102
def _record_function_with_args_exit (handle : torch .Tensor ) -> None : ...
104
- def _supported_activities () -> Set [ProfilerActivity ]: ...
103
+ def _supported_activities () -> set [ProfilerActivity ]: ...
105
104
def _enable_record_function (enable : bool ) -> None : ...
106
105
def _set_empty_test_observer (is_global : bool , sampling_prob : float ) -> None : ...
107
106
def _push_saved_tensors_default_hooks (
@@ -111,11 +110,11 @@ def _push_saved_tensors_default_hooks(
111
110
def _pop_saved_tensors_default_hooks () -> None : ...
112
111
def _unsafe_set_version_counter (t : torch .Tensor , prev_version : int ) -> None : ...
113
112
def _enable_profiler_legacy (config : ProfilerConfig ) -> None : ...
114
- def _disable_profiler_legacy () -> List [ List [ProfilerEvent ]]: ...
113
+ def _disable_profiler_legacy () -> list [ list [ProfilerEvent ]]: ...
115
114
def _profiler_type () -> ActiveProfilerType : ...
116
115
def _saved_tensors_hooks_enable () -> None : ...
117
116
def _saved_tensors_hooks_disable (message : str ) -> None : ...
118
- def _saved_tensors_hooks_get_disabled_error_message () -> Optional [ str ] : ...
117
+ def _saved_tensors_hooks_get_disabled_error_message () -> str | None : ...
119
118
def _saved_tensors_hooks_set_tracing (is_tracing : bool ) -> bool : ...
120
119
121
120
class CreationMeta (Enum ):
0 commit comments