@@ -1,8 +1,6 @@
-from __future__ import annotations
-
 import re
 from dataclasses import dataclass
-from typing import cast, Sequence
+from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
 
 from torchgen import local
 from torchgen.api import cpp
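This import swap is the crux of the whole diff: with `from __future__ import annotations` removed, every annotation in the file is evaluated eagerly at import time, and the PEP 585/604 spellings (`tuple[str, ...]`, `str | None`) only evaluate on Python 3.9/3.10+, while the `typing` equivalents work on older interpreters too. A minimal sketch of the failure mode (not from the PR):

```python
from typing import Optional, Tuple

# The typing-module spelling is an ordinary subscriptable object, so it
# can be evaluated eagerly on Python 3.8+:
var_names: Optional[Tuple[str, ...]] = None

try:
    # This is what the old annotations did at import time once the
    # __future__ import is gone; on Python < 3.10 the `|` union of types
    # is undefined (and on 3.8, `tuple[...]` is not subscriptable).
    evaluated = eval("tuple[str, ...] | None")
except TypeError:
    evaluated = None
```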
@@ -50,16 +48,16 @@ class Derivative:
     original_formula: str
 
     # Names of the arguments for which this formula calculates derivatives.
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Saved inputs that are referenced by the formula.
-    saved_inputs: tuple[SavedAttribute, ...]
+    saved_inputs: Tuple[SavedAttribute, ...]
 
     # Saved outputs that are referenced by the formula.
-    saved_outputs: tuple[SavedAttribute, ...]
+    saved_outputs: Tuple[SavedAttribute, ...]
 
     # Gradients that are referenced by name in the formula.
-    named_gradients: set[str]
+    named_gradients: Set[str]
 
 
 # Represents a forward formula that calculates forward derivatives
@@ -73,17 +71,17 @@ class ForwardDerivative:
 
     # Name of the output arguments for which this formula calculates forward
     # derivatives
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Type of the output arguments for which this formula calculates forward
     # derivatives
-    var_types: tuple[Type, ...]
+    var_types: Tuple[Type, ...]
 
     # Inputs for which the forward derivatives are required for this formula
-    required_inputs_fw_grad: tuple[str, ...] | None
+    required_inputs_fw_grad: Optional[Tuple[str, ...]]
 
     # Inputs for which the primal is required for this formula
-    required_inputs_primal: tuple[str, ...] | None
+    required_inputs_primal: Optional[Tuple[str, ...]]
 
     # Flag to specify if this formula requires the original value of self
     # This is only used by inplace operations
@@ -118,7 +116,7 @@ class DifferentiabilityInfo:
     # The name of the generated autograd function.
     # It's set only if we will calculate a derivative, i.e.
     # 'args_with_derivatives' is not empty.
-    op: str | None
+    op: Optional[str]
 
     # The derivatives formulae for this function.
     # Note that the length of this sequence is the number of differentiable inputs
@@ -140,7 +138,7 @@ class DifferentiabilityInfo:
 
     # The named gradients that are used in any of the derivatives.
     # Invariant: all(name in available_named_gradients for name in used_named_gradients)
-    used_named_gradients: set[str]
+    used_named_gradients: Set[str]
 
     # The function's input arguments for which it calculates derivatives.
     # It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -151,15 +149,15 @@ class DifferentiabilityInfo:
     non_differentiable_arg_names: Sequence[str]
 
     # Raw data read from derivatives.yaml.
-    output_differentiability: list[bool] | None
+    output_differentiability: Optional[List[bool]]
 
     # output_differentiability in derivatives.yaml can be a list of
     # conditions that express if the output is differentiable. In this case,
     # the number of conditions must match the number of outputs
     # (NB: we only support one condition right now).
     # output_differentiability gets populated with True for each condition,
     # while output_differentiability_conditions gets populated with the conditions
-    output_differentiability_conditions: list[str] | None
+    output_differentiability_conditions: Optional[List[str]]
 
     @property
     def has_derivatives(self) -> bool:
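To make the comment block above concrete, the two fields can take either of two shapes after parsing. The values below are hypothetical, not taken from derivatives.yaml:

```python
from typing import List, Optional

# Shape 1: plain booleans -- here the second output is marked
# non-differentiable outright.
output_differentiability: Optional[List[bool]] = [True, False]
output_differentiability_conditions: Optional[List[str]] = None

# Shape 2: a condition string -- the bool list is populated with True for
# each condition and the condition is kept alongside (only one condition
# is supported right now, per the comment above). The condition text here
# is made up for illustration.
output_differentiability = [True]
output_differentiability_conditions = ["grad.defined()"]
```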
@@ -172,7 +170,7 @@ def has_derivatives(self) -> bool:
     # See Note [Codegen'd {view}_copy Operators]
     def create_view_copy_from_view_derivative(
         self, g: NativeFunctionsViewGroup
-    ) -> DifferentiabilityInfo | None:
+    ) -> Optional["DifferentiabilityInfo"]:
         if g.view_copy is None:
             return None
         f = g.view_copy
@@ -203,7 +201,7 @@ def create_view_copy_from_view_derivative(
         )
 
 
-def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
+def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
     if info is None:
         return False
     for derivative in info.derivatives:
@@ -213,11 +211,11 @@ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
     return False
 
 
-def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
+def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "retain_variables")
 
 
-def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
+def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "grad")
 
 
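All three helpers reduce to a whole-word scan of the backward formulas for a bare identifier. A self-contained sketch, assuming `IDENT_REGEX` has the word-boundary shape `(^|\W){}($|\W)` that torchgen uses:

```python
import re

IDENT_REGEX = r"(^|\W){}($|\W)"  # assumed shape of torchgen's constant

def formula_uses(ident: str, formula: str) -> bool:
    # True only when `ident` appears as a standalone token in the formula.
    return re.search(IDENT_REGEX.format(ident), formula) is not None

assert formula_uses("grad", "grad * self.conj()")
assert not formula_uses("grad", "grads[0] * 2")  # substring, not a token
```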
@@ -255,8 +253,8 @@ class DifferentiableOutput:
 @dataclass(frozen=True)
 class NativeFunctionWithDifferentiabilityInfo:
     func: NativeFunction
-    info: dict[str, DifferentiabilityInfo] | None
-    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
+    info: Optional[Dict[str, DifferentiabilityInfo]]
+    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
 
 
 # TODO: Update comment below since it is out of date.
@@ -365,19 +363,19 @@ def is_reference_for_foreach(
 # TODO(crcrpar): Avoid hard coding "Default" ideally.
 def gen_foreach_derivativeinfo(
     foreach_function: NativeFunction,
-    functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
-    non_functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    non_functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
     dispatch_key: str = "Default",
-) -> tuple[DifferentiabilityInfo | None, bool]:
+) -> Tuple[Optional[DifferentiabilityInfo], bool]:
     """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
 
     The second return value indicates whether the info is generated in this function.
     """
-    ref_diff_info: DifferentiabilityInfo | None = None
+    ref_diff_info: Optional[DifferentiabilityInfo] = None
 
     for function_schema, diff_info in functional_info_by_signature.items():
         if not is_reference_for_foreach(foreach_function, function_schema):
@@ -487,13 +485,13 @@ def gen_foreach_derivativeinfo(
         if arg.name in all_var_names
     ]
 
-    forward_derivatives: list[ForwardDerivative] = []
+    forward_derivatives: List[ForwardDerivative] = []
     fw_derivative: ForwardDerivative
     for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: list[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: list[str] = []
-        required_inputs_primal: list[str] = []
+        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
+        var_types: List[Type] = list(fw_derivative.var_types)
+        required_inputs_fw_grad: List[str] = []
+        required_inputs_primal: List[str] = []
         if fw_derivative.required_inputs_fw_grad is not None:
             required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
         if fw_derivative.required_inputs_primal:
@@ -580,9 +578,9 @@ def gen_foreach_derivativeinfo(
 
 
 def match_differentiability_info(
-    native_functions: list[NativeFunction],
-    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
-) -> list[NativeFunctionWithDifferentiabilityInfo]:
+    native_functions: List[NativeFunction],
+    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+) -> List[NativeFunctionWithDifferentiabilityInfo]:
     """Sets the "derivative" key on declarations to matching autograd function
     In-place functions will use the out-of-place derivative definition if there
     is no in-place specific derivative.
@@ -601,7 +599,7 @@ def match_differentiability_info(
 
     def find_info(
         f: NativeFunction,
-    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
+    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
         # Don't bother matching info to generated out= variants
         if "generated" in f.tags and f.func.kind() == SchemaKind.out:
             return None, False
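A hedged sketch of the fallback that `find_info` implements per the docstring above, with plain strings standing in for `FunctionSchema` (the schemas and info values are hypothetical): an exact match wins; otherwise an in-place function reuses the info registered for its out-of-place signature, and the bool reports which case occurred.

```python
from typing import Dict, Optional, Tuple

infos: Dict[str, Dict[str, str]] = {
    "abs(Tensor self) -> Tensor": {"Default": "AbsBackward"},
}

def find_info(func: str, functional_sig: str) -> Tuple[Optional[Dict[str, str]], bool]:
    if func in infos:
        return infos[func], True  # exact (e.g. in-place-specific) match
    if functional_sig in infos:
        return infos[functional_sig], False  # out-of-place fallback
    return None, False

# abs_ has no entry of its own, so it reuses abs's derivative info:
print(find_info("abs_(Tensor(a!) self) -> Tensor(a!)",
                "abs(Tensor self) -> Tensor"))
```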
@@ -655,7 +653,7 @@ def find_info(
 
         return None, False
 
-    result: list[NativeFunctionWithDifferentiabilityInfo] = []
+    result: List[NativeFunctionWithDifferentiabilityInfo] = []
     for f in native_functions:
         info_dict, is_exact_match = find_info(f)
 
@@ -679,7 +677,7 @@ def find_info(
             )
             continue
 
-        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
+        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
         for key, info in info_dict.items():
             if not info.forward_derivatives:
                 fw_derivative_dict[key] = []
@@ -715,7 +713,7 @@ def find_info(
                 formula = fw_info.formula
 
                 def replace_self_with_original_self(formula: str, postfix: str) -> str:
-                    def repl(m: re.Match[str]) -> str:
+                    def repl(m: Match[str]) -> str:
                         return f"{m.group(1)}original_self{postfix}{m.group(2)}"
 
                     return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
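The nested `repl` above (like the `"result"` → `"self_p"` rewrite in the next hunk) keeps groups 1 and 2 so the non-word characters around the matched identifier survive the substitution. A runnable sketch under the same `IDENT_REGEX` assumption as before; `re.Match` is left unsubscripted so it runs on any interpreter, while the diff switches to `typing.Match` for its older targets:

```python
import re

IDENT_REGEX = r"(^|\W){}($|\W)"  # assumed shape of torchgen's constant

def replace_self_with_original_self(formula: str, postfix: str) -> str:
    def repl(m: re.Match) -> str:
        # Groups 1/2 are the non-word neighbors of the matched identifier.
        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

print(replace_self_with_original_self("grad * self_p.conj()", "_p"))
# grad * original_self_p.conj()
```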
@@ -736,7 +734,7 @@ def repl(m: re.Match[str]) -> str:
                     formula = replace_self_with_original_self(formula, "_t")
 
                 # replace "result" from the formula by "self_p"
-                def repl(m: re.Match[str]) -> str:
+                def repl(m: Match[str]) -> str:
                     return f"{m.group(1)}self_p{m.group(2)}"
 
                 formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@@ -760,8 +758,8 @@ def repl(m: re.Match[str]) -> str:
                 # If there is a need, we can relax (2) to allow any op that has an in-place variant
                 is_single_method_on_self_t = False
                 directly_do_inplace = False
-                op_name: str | None = None
-                between_parens: str | None = None
+                op_name: Optional[str] = None
+                between_parens: Optional[str] = None
                 match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                 if match:
                     op_name, between_parens = match.group(1), match.group(2)
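For reference, the `fullmatch` above accepts only a formula that is exactly one method call on `self_t`, splitting it into the method name and the raw argument text. The formulas below are hypothetical:

```python
import re

m = re.fullmatch(r"self_t.([\w]*)\((.*)\)", "self_t.mul_(grad)")
assert m is not None
assert (m.group(1), m.group(2)) == ("mul_", "grad")

# A compound expression fails the *full* match, so the single-method
# fast path is skipped for it:
assert re.fullmatch(r"self_t.([\w]*)\((.*)\)", "self_t + grad") is None
```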
@@ -825,7 +823,7 @@ def check_parens_nest_level_gt_zero(s: str) -> bool:
 
 
 def is_differentiable(
-    name: str, type: Type, info: DifferentiabilityInfo | None
+    name: str, type: Type, info: Optional[DifferentiabilityInfo]
 ) -> bool:
     return type.is_tensor_like() and (
         info is None or name not in info.non_differentiable_arg_names
@@ -834,10 +832,10 @@ def is_differentiable(
 
 def gen_differentiable_outputs(
     fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> list[DifferentiableOutput]:
+) -> List[DifferentiableOutput]:
     f = fn.func
     info = fn.info[key] if fn.info else None
-    outputs: list[DifferentiableOutput] = [
+    outputs: List[DifferentiableOutput] = [
         DifferentiableOutput(
             name=name,
             type=ret.type,
@@ -852,7 +850,7 @@ def gen_differentiable_outputs(
             f"The length of output_differentiability ({len(output_differentiability)}), "
             f"does not match the number of outputs ({len(outputs)})."
         )
-    differentiable_outputs: list[DifferentiableOutput] = []
+    differentiable_outputs: List[DifferentiableOutput] = []
     if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
         raise RuntimeError(
             "output_differentiability=False for inplace operation (version_counter won't get updated)"