[BE][Easy] enable postponed annotations in `torchgen` by XuehaiPan · Pull Request #129376 · pytorch/pytorch · GitHub

[BE][Easy] enable postponed annotations in torchgen #129376


Closed
XuehaiPan wants to merge 20 commits
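
For context, here is a minimal, self-contained sketch of what postponed evaluation of annotations (PEP 563) enables; it is illustrative only and not code from this PR. With `from __future__ import annotations` in effect, PEP 585 builtin generics such as `tuple[str, ...]` and PEP 604 unions such as `dict[str, int] | None` are legal in annotations even on Python 3.8/3.9, because every annotation is stored as a string and never evaluated at definition time.

# Minimal illustration of postponed annotations (assumed example, not from the PR).
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Example:
    # Builtin generics and `|` unions are fine here even on Python 3.8/3.9,
    # because the annotations below are kept as plain strings.
    names: tuple[str, ...]
    info: dict[str, int] | None = None


print(Example.__annotations__)
# {'names': 'tuple[str, ...]', 'info': 'dict[str, int] | None'}
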
90 changes: 46 additions & 44 deletions torchgen/api/autograd.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
from typing import cast, Sequence

from torchgen import local
from torchgen.api import cpp
@@ -48,16 +50,16 @@
original_formula: str

# Names of the arguments for which this formula calculates derivatives.
var_names: Tuple[str, ...]
var_names: tuple[str, ...]

# Saved inputs that are referenced by the formula.
saved_inputs: Tuple[SavedAttribute, ...]
saved_inputs: tuple[SavedAttribute, ...]

# Saved outputs that are referenced by the formula.
saved_outputs: Tuple[SavedAttribute, ...]
saved_outputs: tuple[SavedAttribute, ...]

# Gradients that are referenced by name in the formula.
named_gradients: Set[str]
named_gradients: set[str]


# Represents a forward formula that calculates forward derivatives
@@ -71,17 +73,17 @@

# Name of the output arguments for which this formula calculates forward
# derivatives
var_names: Tuple[str, ...]
var_names: tuple[str, ...]

# Type of the output arguments for which this formula calculates forward
# derivatives
var_types: Tuple[Type, ...]
var_types: tuple[Type, ...]

# Inputs for which the forward derivatives are required for this formula
required_inputs_fw_grad: Optional[Tuple[str, ...]]
required_inputs_fw_grad: tuple[str, ...] | None

# Inputs for which the primal is required for this formula
required_inputs_primal: Optional[Tuple[str, ...]]
required_inputs_primal: tuple[str, ...] | None

# Flag to specify if this formula requires the original value of self
# This is only used by inplace operations
@@ -116,7 +118,7 @@
# The name of the generated autograd function.
# It's set only if we will calculate a derivative, i.e.
# 'args_with_derivatives' is not empty.
op: Optional[str]
op: str | None

# The derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable inputs
@@ -138,7 +140,7 @@

# The named gradients that are used in any of the derivatives.
# Invariant: all(name in available_named_gradients for name in used_named_gradients)
used_named_gradients: Set[str]
used_named_gradients: set[str]

# The function's input arguments for which it calculates derivatives.
# It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -149,15 +151,15 @@
non_differentiable_arg_names: Sequence[str]

# Raw data read from derivatives.yaml.
output_differentiability: Optional[List[bool]]
output_differentiability: list[bool] | None

# output_differentiability in derivatives.yaml can be a list of
# conditions that express if the output is differentiable. In this case,
# the number of conditions must match the number of outputs
# (NB: we only support one condition right now).
# output_differentiability gets populated with True for each condition,
# while output_differentiability_conditions gets populated with the conditions
output_differentiability_conditions: Optional[List[str]]
output_differentiability_conditions: list[str] | None

@property
def has_derivatives(self) -> bool:
@@ -170,7 +172,7 @@
# See Note [Codegen'd {view}_copy Operators]
def create_view_copy_from_view_derivative(
self, g: NativeFunctionsViewGroup
) -> Optional["DifferentiabilityInfo"]:
) -> DifferentiabilityInfo | None:
if g.view_copy is None:
return None
f = g.view_copy
@@ -201,7 +203,7 @@
)


def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
if info is None:
return False
for derivative in info.derivatives:
@@ -211,11 +213,11 @@
return False


def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "retain_variables")


def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "grad")


@@ -253,8 +255,8 @@
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
info: Optional[Dict[str, DifferentiabilityInfo]]
fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
info: dict[str, DifferentiabilityInfo] | None
fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
@@ -363,19 +365,19 @@
# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
foreach_function: NativeFunction,
functional_info_by_signature: Dict[
FunctionSchema, Dict[str, DifferentiabilityInfo]
functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
Check notice on line 368 in torchgen/api/autograd.py (GitHub Actions / bc_linter): Function gen_foreach_derivativeinfo: functional_info_by_signature changed from Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] to dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
non_functional_info_by_signature: Dict[
FunctionSchema, Dict[str, DifferentiabilityInfo]
non_functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
Check notice on line 371 in torchgen/api/autograd.py (GitHub Actions / bc_linter): Function gen_foreach_derivativeinfo: non_functional_info_by_signature changed from Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] to dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
dispatch_key: str = "Default",
) -> Tuple[Optional[DifferentiabilityInfo], bool]:
) -> tuple[DifferentiabilityInfo | None, bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

The second return value indicates whether the info is generated in this function.
"""
ref_diff_info: Optional[DifferentiabilityInfo] = None
ref_diff_info: DifferentiabilityInfo | None = None

for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
@@ -485,13 +487,13 @@
if arg.name in all_var_names
]

forward_derivatives: List[ForwardDerivative] = []
forward_derivatives: list[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: List[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: List[str] = []
required_inputs_primal: List[str] = []
var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: list[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: list[str] = []
required_inputs_primal: list[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
@@ -578,9 +580,9 @@


def match_differentiability_info(
native_functions: List[NativeFunction],
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
native_functions: list[NativeFunction],

Check notice on line 583 in torchgen/api/autograd.py (GitHub Actions / bc_linter): Function match_differentiability_info: native_functions changed from List[NativeFunction] to list[NativeFunction]
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],

Check notice on line 584 in torchgen/api/autograd.py (GitHub Actions / bc_linter): Function match_differentiability_info: differentiability_infos changed from Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] to dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
) -> list[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
@@ -599,7 +601,7 @@

def find_info(
f: NativeFunction,
) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
# Don't bother matching info to generated out= variants
if "generated" in f.tags and f.func.kind() == SchemaKind.out:
return None, False
@@ -653,7 +655,7 @@

return None, False

result: List[NativeFunctionWithDifferentiabilityInfo] = []
result: list[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info_dict, is_exact_match = find_info(f)

@@ -677,7 +679,7 @@
)
continue

fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
for key, info in info_dict.items():
if not info.forward_derivatives:
fw_derivative_dict[key] = []
@@ -713,7 +715,7 @@
formula = fw_info.formula

def replace_self_with_original_self(formula: str, postfix: str) -> str:
def repl(m: Match[str]) -> str:
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}original_self{postfix}{m.group(2)}"

return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@@ -734,7 +736,7 @@
formula = replace_self_with_original_self(formula, "_t")

# replace "result" from the formula by "self_p"
def repl(m: Match[str]) -> str:
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}self_p{m.group(2)}"

formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@@ -758,8 +760,8 @@
# If there is a need, we can relax (2) to allow any op that has an in-place variant
is_single_method_on_self_t = False
directly_do_inplace = False
op_name: Optional[str] = None
between_parens: Optional[str] = None
op_name: str | None = None
between_parens: str | None = None
match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
if match:
op_name, between_parens = match.group(1), match.group(2)
@@ -823,7 +825,7 @@


def is_differentiable(
name: str, type: Type, info: Optional[DifferentiabilityInfo]
name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
return type.is_tensor_like() and (
info is None or name not in info.non_differentiable_arg_names
@@ -832,10 +834,10 @@

def gen_differentiable_outputs(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[DifferentiableOutput]:
) -> list[DifferentiableOutput]:
f = fn.func
info = fn.info[key] if fn.info else None
outputs: List[DifferentiableOutput] = [
outputs: list[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
@@ -850,7 +852,7 @@
f"The length of output_differentiability ({len(output_differentiability)}), "
f"does not match the number of outputs ({len(outputs)})."
)
differentiable_outputs: List[DifferentiableOutput] = []
differentiable_outputs: list[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError(
"output_differentiability=False for inplace operation (version_counter won't get updated)"
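A hedged aside on the pattern applied throughout torchgen/api/autograd.py above (my illustration, not part of the diff): deferred annotations such as `re.Match[str]` and `str | None` remain unevaluated strings, so they cost nothing at import time on Python 3.8; they only require a newer interpreter if something actually resolves them, for example via `typing.get_type_hints`, which can evaluate `X | None` at runtime on Python 3.10+.

# Assumed standalone example, not taken from torchgen.
from __future__ import annotations

import re
import sys
import typing


def first_group(m: re.Match[str]) -> str | None:
    # `re.Match` replaces the deprecated `typing.Match` alias removed above.
    return m.group(1) if m.lastindex else None


# The raw annotations are plain strings:
print(first_group.__annotations__)
# {'m': 're.Match[str]', 'return': 'str | None'}

# Resolving them requires an interpreter that understands the syntax:
if sys.version_info >= (3, 10):
    print(typing.get_type_hints(first_group))
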
26 changes: 14 additions & 12 deletions torchgen/api/cpp.py
@@ -1,4 +1,6 @@
from typing import List, Optional, Sequence, Set, Union
from __future__ import annotations

from typing import Sequence

from torchgen import local
from torchgen.api.types import (
@@ -94,7 +96,7 @@
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> Optional[NamedCType]:
) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
@@ -279,7 +281,7 @@


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: List[str] = []
returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
@@ -367,18 +369,18 @@
# Convert an argument into its C++ API form


def argument(

Check notice on line 372 in torchgen/api/cpp.py (GitHub Actions / bc_linter): Function argument: cpp_no_default_args changed from Set[str] to set[str]
a: Union[Argument, TensorOptionsArguments, SelfArgument],
a: Argument | TensorOptionsArguments | SelfArgument,
*,
cpp_no_default_args: Set[str],
cpp_no_default_args: set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
) -> List[Binding]:
) -> list[Binding]:
def sub_argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Binding]:
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
@@ -394,7 +396,7 @@
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: Optional[str] = None
default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type, symint=symint)
return [
@@ -439,15 +441,15 @@
assert_never(a)


def arguments(

Check notice on line 444 in torchgen/api/cpp.py (GitHub Actions / bc_linter): Function arguments: cpp_no_default_args changed from Set[str] to set[str]
arguments: Arguments,
*,
faithful: bool,
symint: bool = False,
method: bool,
cpp_no_default_args: Set[str],
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
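One detail visible in the torchgen/api/cpp.py imports above, stated as my reading rather than the PR's own rationale: postponing evaluation only affects annotations, so names used at runtime (such as the `cast` call kept in torchgen/api/autograd.py) or still referenced by static type checkers (such as `Sequence`) must remain imported; only the `Dict`/`List`/`Optional`/`Set`/`Tuple`/`Union` aliases become redundant, since builtin generics and `|` unions replace them without any import. A small assumed example:

# Assumed example (not from the PR): `cast` is an ordinary runtime call and
# cannot be deferred the way annotations are, so its import has to stay.
from __future__ import annotations

from typing import Sequence, cast


def first_binding(names: Sequence[str] | None) -> str | None:
    # The annotation above is deferred; the `cast` call below executes normally.
    return cast(str, names[0]) if names else None


print(first_binding(("self", "other")))  # prints: self
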
12 changes: 7 additions & 5 deletions torchgen/api/dispatcher.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import itertools
from typing import List, Sequence, Union
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -76,10 +78,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
return cpp.returns_type(rs, symint=symint)


def jit_arguments(func: FunctionSchema) -> List[Argument]:
def jit_arguments(func: FunctionSchema) -> list[Argument]:
def to_argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Argument]:
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Argument]:
if isinstance(a, Argument):
return [a]
elif isinstance(a, SelfArgument):
@@ -114,5 +116,5 @@ def argument(
)


def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
return [argument(a, symint=symint) for a in jit_arguments(func)]
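
The same mechanical rewrite repeats across the remaining torchgen modules in this PR. Below is a rough, purely illustrative sketch of how leftover legacy aliases could be detected; the PR may instead have relied on an existing tool such as pyupgrade or Ruff's UP006/UP007 rules, so treat every name here as an assumption.

# Illustrative helper (assumed, not part of torchgen): report `from typing`
# imports of aliases that PEP 585/604 syntax makes redundant.
from __future__ import annotations

import ast
import sys

LEGACY_ALIASES = {"Dict", "List", "Set", "Tuple", "Optional", "Union", "Match"}


def find_legacy_typing_imports(path: str) -> list[tuple[int, str]]:
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)
    hits: list[tuple[int, str]] = []
    for node in ast.walk(tree):
        # Only `from typing import ...` statements are of interest here.
        if isinstance(node, ast.ImportFrom) and node.module == "typing":
            hits.extend(
                (node.lineno, alias.name)
                for alias in node.names
                if alias.name in LEGACY_ALIASES
            )
    return hits


if __name__ == "__main__":
    for lineno, name in find_legacy_typing_imports(sys.argv[1]):
        print(f"{sys.argv[1]}:{lineno}: still imports typing.{name}")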