Revert "[BE][Easy] enable postponed annotations in `torchgen` (#129376)" · pytorch/pytorch@6063bb9

Commit 6063bb9

Revert "[BE][Easy] enable postponed annotations in torchgen (#129376)"
This reverts commit 494057d. Reverted #129376 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert #129374, please do a rebase and reland this ([comment](#129375 (comment)))
1 parent 83caf49 commit 6063bb9
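
For context: the reverted change added `from __future__ import annotations` (PEP 563) across `torchgen`, which defers annotation evaluation and so allows built-in generics (`list[str]`, `tuple[str, ...]`) and PEP 604 unions (`str | None`) even on Python versions older than 3.9/3.10. With that import removed, annotations are evaluated eagerly again, so the diffs below mechanically restore the `typing` spellings (`List`, `Tuple`, `Optional`, `Union`, `Set`, `Dict`) and their imports. A minimal sketch of the two styles; the function is illustrative, not taken from the diff:

from typing import List, Optional  # the spelling this revert restores


def first_match(pattern: str, names: List[str]) -> Optional[str]:
    """Return the first name containing `pattern`, or None if there is none."""
    for name in names:
        if pattern in name:
            return name
    return None


# With `from __future__ import annotations` at the top of the module (the
# change being reverted), the same signature could be written as:
#
#     def first_match(pattern: str, names: list[str]) -> str | None: ...
#
# even on Python < 3.10, because the annotations are never evaluated at runtime.

print(first_match("grad", ["op", "grad_fn"]))  # prints: grad_fn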

45 files changed: +900 −976 lines

Note: this is a large commit, so only part of the diff is shown below (3 of the 45 changed files).
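
One more piece of background on the mechanics (standard Python behavior, not something stated in this commit): under PEP 563, annotations are stored as plain strings and only resolved on demand, which is what lets a module use `list[str]`-style annotations on interpreters whose runtime does not support them. A small illustrative snippet (Python 3.9+; the helper is hypothetical, not from torchgen):

from __future__ import annotations  # PEP 563: postponed evaluation of annotations

import typing


def head(xs: list[str]) -> str:
    """Return the first element of a non-empty list."""
    return xs[0]


# Under PEP 563, the raw annotation is stored as a plain string...
print(repr(head.__annotations__["xs"]))   # prints: 'list[str]'
# ...and is only evaluated to a real type when explicitly requested:
print(typing.get_type_hints(head)["xs"])  # prints: list[str]

This deferred evaluation is also why the revert must swap the spellings back: once the `__future__` import is gone, an annotation like `str | None` is executed eagerly and raises a `TypeError` on interpreters before 3.10.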

torchgen/api/autograd.py

Lines changed: 44 additions & 46 deletions
@@ -1,8 +1,6 @@
-from __future__ import annotations
-
 import re
 from dataclasses import dataclass
-from typing import cast, Sequence
+from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
 
 from torchgen import local
 from torchgen.api import cpp
@@ -50,16 +48,16 @@ class Derivative:
     original_formula: str
 
     # Names of the arguments for which this formula calculates derivatives.
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Saved inputs that are referenced by the formula.
-    saved_inputs: tuple[SavedAttribute, ...]
+    saved_inputs: Tuple[SavedAttribute, ...]
 
     # Saved outputs that are referenced by the formula.
-    saved_outputs: tuple[SavedAttribute, ...]
+    saved_outputs: Tuple[SavedAttribute, ...]
 
     # Gradients that are referenced by name in the formula.
-    named_gradients: set[str]
+    named_gradients: Set[str]
 
 
 # Represents a forward formula that calculates forward derivatives
@@ -73,17 +71,17 @@ class ForwardDerivative:
 
     # Name of the output arguments for which this formula calculates forward
     # derivatives
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Type of the output arguments for which this formula calculates forward
     # derivatives
-    var_types: tuple[Type, ...]
+    var_types: Tuple[Type, ...]
 
     # Inputs for which the forward derivatives are required for this formula
-    required_inputs_fw_grad: tuple[str, ...] | None
+    required_inputs_fw_grad: Optional[Tuple[str, ...]]
 
     # Inputs for which the primal is required for this formula
-    required_inputs_primal: tuple[str, ...] | None
+    required_inputs_primal: Optional[Tuple[str, ...]]
 
     # Flag to specify if this formula requires the original value of self
     # This is only used by inplace operations
@@ -118,7 +116,7 @@ class DifferentiabilityInfo:
     # The name of the generated autograd function.
     # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
-    op: str | None
+    op: Optional[str]
 
     # The derivatives formulae for this function.
     # Note that the length of this sequence is the number of differentiable inputs
@@ -140,7 +138,7 @@ class DifferentiabilityInfo:
 
     # The named gradients that are used in any of the derivatives.
     # Invariant: all(name in available_named_gradients for name in used_named_gradients)
-    used_named_gradients: set[str]
+    used_named_gradients: Set[str]
 
     # The function's input arguments for which it calculates derivatives.
     # It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -151,15 +149,15 @@ class DifferentiabilityInfo:
     non_differentiable_arg_names: Sequence[str]
 
     # Raw data read from derivatives.yaml.
-    output_differentiability: list[bool] | None
+    output_differentiability: Optional[List[bool]]
 
     # output_differentiability in derivatives.yaml can be a list of
     # conditions that express if the output is differentiable. In this case,
     # the number of conditions must match the number of outputs
     # (NB: we only support one condition right now).
     # output_differentiability gets populated with True for each condition,
     # while output_differentiability_conditions gets populated with the conditions
-    output_differentiability_conditions: list[str] | None
+    output_differentiability_conditions: Optional[List[str]]
 
     @property
     def has_derivatives(self) -> bool:
@@ -172,7 +170,7 @@ def has_derivatives(self) -> bool:
     # See Note [Codegen'd {view}_copy Operators]
     def create_view_copy_from_view_derivative(
         self, g: NativeFunctionsViewGroup
-    ) -> DifferentiabilityInfo | None:
+    ) -> Optional["DifferentiabilityInfo"]:
         if g.view_copy is None:
             return None
         f = g.view_copy
@@ -203,7 +201,7 @@ def create_view_copy_from_view_derivative(
         )
 
 
-def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
+def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
     if info is None:
         return False
     for derivative in info.derivatives:
@@ -213,11 +211,11 @@ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
     return False
 
 
-def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
+def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "retain_variables")
 
 
-def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
+def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "grad")
 
 
@@ -255,8 +253,8 @@ class DifferentiableOutput:
 @dataclass(frozen=True)
 class NativeFunctionWithDifferentiabilityInfo:
     func: NativeFunction
-    info: dict[str, DifferentiabilityInfo] | None
-    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
+    info: Optional[Dict[str, DifferentiabilityInfo]]
+    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
 
 
 # TODO: Update comment below since it is out of date.
@@ -365,19 +363,19 @@ def is_reference_for_foreach(
 # TODO(crcrpar): Avoid hard coding "Default" ideally.
 def gen_foreach_derivativeinfo(
     foreach_function: NativeFunction,
-    functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
-    non_functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    non_functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
     dispatch_key: str = "Default",
-) -> tuple[DifferentiabilityInfo | None, bool]:
+) -> Tuple[Optional[DifferentiabilityInfo], bool]:
     """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
 
     The second return value indicates whether the info is generated in this function.
     """
-    ref_diff_info: DifferentiabilityInfo | None = None
+    ref_diff_info: Optional[DifferentiabilityInfo] = None
 
     for function_schema, diff_info in functional_info_by_signature.items():
         if not is_reference_for_foreach(foreach_function, function_schema):
@@ -487,13 +485,13 @@ def gen_foreach_derivativeinfo(
         if arg.name in all_var_names
     ]
 
-    forward_derivatives: list[ForwardDerivative] = []
+    forward_derivatives: List[ForwardDerivative] = []
     fw_derivative: ForwardDerivative
     for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: list[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: list[str] = []
-        required_inputs_primal: list[str] = []
+        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
+        var_types: List[Type] = list(fw_derivative.var_types)
+        required_inputs_fw_grad: List[str] = []
+        required_inputs_primal: List[str] = []
         if fw_derivative.required_inputs_fw_grad is not None:
             required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
         if fw_derivative.required_inputs_primal:
@@ -580,9 +578,9 @@ def gen_foreach_derivativeinfo(
 
 
 def match_differentiability_info(
-    native_functions: list[NativeFunction],
-    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
-) -> list[NativeFunctionWithDifferentiabilityInfo]:
+    native_functions: List[NativeFunction],
+    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+) -> List[NativeFunctionWithDifferentiabilityInfo]:
     """Sets the "derivative" key on declarations to matching autograd function
     In-place functions will use the out-of-place derivative definition if there
     is no in-place specific derivative.
@@ -601,7 +599,7 @@ def match_differentiability_info(
 
     def find_info(
         f: NativeFunction,
-    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
+    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False
@@ -655,7 +653,7 @@ def find_info(
 
        return None, False
 
-    result: list[NativeFunctionWithDifferentiabilityInfo] = []
+    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)
 
@@ -679,7 +677,7 @@ def find_info(
            )
            continue
 
-        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
+        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
@@ -715,7 +713,7 @@ def find_info(
                formula = fw_info.formula
 
                def replace_self_with_original_self(formula: str, postfix: str) -> str:
-                    def repl(m: re.Match[str]) -> str:
+                    def repl(m: Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"
 
                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@@ -736,7 +734,7 @@ def repl(m: re.Match[str]) -> str:
                formula = replace_self_with_original_self(formula, "_t")
 
                # replace "result" from the formula by "self_p"
-                def repl(m: re.Match[str]) -> str:
+                def repl(m: Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"
 
                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@@ -760,8 +758,8 @@ def repl(m: re.Match[str]) -> str:
                # If there is a need, we can relax (2) to allow any op that has an in-place variant
                is_single_method_on_self_t = False
                directly_do_inplace = False
-                op_name: str | None = None
-                between_parens: str | None = None
+                op_name: Optional[str] = None
+                between_parens: Optional[str] = None
                match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                if match:
                    op_name, between_parens = match.group(1), match.group(2)
@@ -825,7 +823,7 @@ def check_parens_nest_level_gt_zero(s: str) -> bool:
 
 
 def is_differentiable(
-    name: str, type: Type, info: DifferentiabilityInfo | None
+    name: str, type: Type, info: Optional[DifferentiabilityInfo]
 ) -> bool:
     return type.is_tensor_like() and (
         info is None or name not in info.non_differentiable_arg_names
@@ -834,10 +832,10 @@ def is_differentiable(
 
 def gen_differentiable_outputs(
     fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> list[DifferentiableOutput]:
+) -> List[DifferentiableOutput]:
     f = fn.func
     info = fn.info[key] if fn.info else None
-    outputs: list[DifferentiableOutput] = [
+    outputs: List[DifferentiableOutput] = [
         DifferentiableOutput(
             name=name,
             type=ret.type,
@@ -852,7 +850,7 @@ def gen_differentiable_outputs(
             f"The length of output_differentiability ({len(output_differentiability)}), "
             f"does not match the number of outputs ({len(outputs)})."
         )
-    differentiable_outputs: list[DifferentiableOutput] = []
+    differentiable_outputs: List[DifferentiableOutput] = []
     if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
         raise RuntimeError(
             "output_differentiability=False for inplace operation (version_counter won't get updated)"

torchgen/api/cpp.py

Lines changed: 12 additions & 14 deletions
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import Sequence
+from typing import List, Optional, Sequence, Set, Union
 
 from torchgen import local
 from torchgen.api.types import (
@@ -96,7 +94,7 @@ def valuetype_type(
     binds: ArgName,
     remove_non_owning_ref_types: bool = False,
     symint: bool = False,
-) -> NamedCType | None:
+) -> Optional[NamedCType]:
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
             return None
@@ -281,7 +279,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
 
 
 def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
-    returns: list[str] = []
+    returns: List[str] = []
     for i, r in enumerate(f.func.returns):
         # If we have an inplace function, the return argument is
         # implicitly named self.
@@ -370,17 +368,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
 
 
 def argument(
-    a: Argument | TensorOptionsArguments | SelfArgument,
+    a: Union[Argument, TensorOptionsArguments, SelfArgument],
     *,
-    cpp_no_default_args: set[str],
+    cpp_no_default_args: Set[str],
     method: bool,
     faithful: bool,
     symint: bool = False,
     has_tensor_options: bool,
-) -> list[Binding]:
+) -> List[Binding]:
     def sub_argument(
-        a: Argument | TensorOptionsArguments | SelfArgument,
-    ) -> list[Binding]:
+        a: Union[Argument, TensorOptionsArguments, SelfArgument]
+    ) -> List[Binding]:
         return argument(
             a,
             cpp_no_default_args=cpp_no_default_args,
@@ -396,7 +394,7 @@ def sub_argument(
         binds = SpecialArgName.possibly_redundant_memory_format
     else:
         binds = a.name
-    default: str | None = None
+    default: Optional[str] = None
     if a.name not in cpp_no_default_args and a.default is not None:
         default = default_expr(a.default, a.type, symint=symint)
     return [
@@ -447,9 +445,9 @@ def arguments(
     faithful: bool,
     symint: bool = False,
     method: bool,
-    cpp_no_default_args: set[str],
-) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    cpp_no_default_args: Set[str],
+) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
     if faithful:
         args.extend(arguments.non_out)
         args.extend(arguments.out)

torchgen/api/dispatcher.py

Lines changed: 5 additions & 7 deletions
@@ -1,7 +1,5 @@
-from __future__ import annotations
-
 import itertools
-from typing import Sequence
+from typing import List, Sequence, Union
 
 from torchgen.api import cpp
 from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -78,10 +76,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
     return cpp.returns_type(rs, symint=symint)
 
 
-def jit_arguments(func: FunctionSchema) -> list[Argument]:
+def jit_arguments(func: FunctionSchema) -> List[Argument]:
     def to_argument(
-        a: Argument | TensorOptionsArguments | SelfArgument,
-    ) -> list[Argument]:
+        a: Union[Argument, TensorOptionsArguments, SelfArgument]
+    ) -> List[Argument]:
         if isinstance(a, Argument):
             return [a]
         elif isinstance(a, SelfArgument):
@@ -116,5 +114,5 @@ def argument(
     )
 
 
-def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
+def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
     return [argument(a, symint=symint) for a in jit_arguments(func)]
