8000 [caffe2/tools/autograd] Fix non-determinism in code gen (#101425) · pytorch/pytorch@799ef7e · GitHub
[go: up one dir, main page]

Skip to content

Commit 799ef7e

Browse files
andrewjcgpytorchmergebot
authored andcommitted
[caffe2/tools/autograd] Fix non-determinism in code gen (#101425)
Fix several cases of leaking set-iteration-order to generated sources, causing non-determinism in generated code. Pull Request resolved: #101425 Approved by: https://github.com/albanD
1 parent a837609 commit 799ef7e

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

tools/autograd/gen_autograd_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,9 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
745745
PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
746746
)
747747

748-
for var in info.all_saved_inputs:
748+
for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
749749
save_var(var, is_output=False)
750-
for var in info.all_saved_outputs:
750+
for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
751751
save_var(var, is_output=True)
752752

753753
# lock the mutex when we release variables and in Node::apply to protect thread safety
@@ -770,7 +770,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
770770
# Generate aliases for gradients named for returned values.
771771
body.extend(
772772
f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
773-
for name in info.used_named_gradients
773+
for name in sorted(info.used_named_gradients)
774774
)
775775

776776
def emit_derivative(

tools/autograd/gen_python_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,9 @@ def gen(
365365
def gen_tags_enum() -> Dict[str, str]:
366366
return {
367367
"enum_of_valid_tags": (
368-
"".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
368+
"".join(
369+
[f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
370+
)
369371
)
370372
}
371373

tools/autograd/gen_variable_type.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def gen_variable_type(
818818
# dispatch key that appears in derivatives.yaml
819819
def wrapper_registrations(used_keys: Set[str]) -> str:
820820
library_impl_macro_list: List[str] = []
821-
for key in used_keys:
821+
for key in sorted(used_keys):
822822
dispatch_key = key
823823
if key == "Default":
824824
dispatch_key = "Autograd"
@@ -843,7 +843,7 @@ def wrapper_registrations(used_keys: Set[str]) -> str:
843843
"type_derived_method_definitions": "\n\n".join(
844844
[
845845
"${" + f"type_derived_method_definitions_{key}" + "}"
846-
for key in used_keys
846+
for key in sorted(used_keys)
847847
]
848848
),
849849
"wrapper_registrations": wrapper_registrations(used_keys),
@@ -854,8 +854,8 @@ def wrapper_registrations(used_keys: Set[str]) -> str:
854854
fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False)
855855

856856
sharded_keys = set(
857-
[f"type_derived_method_definitions_{key}" for key in used_keys]
858-
+ [f"wrapper_registrations_{key}" for key in used_keys]
857+
[f"type_derived_method_definitions_{key}" for key in sorted(used_keys)]
858+
+ [f"wrapper_registrations_{key}" for key in sorted(used_keys)]
859859
)
860860
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
861861
# template regarding sharding of the generated files.
@@ -1337,7 +1337,7 @@ def save_variables(
13371337
) -> Sequence[str]:
13381338
# assign the saved variables to the generated grad_fn
13391339
stmts: List[str] = []
1340-
for arg in saved_variables:
1340+
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
13411341
name = (
13421342
arg.nctype.name.name
13431343
if isinstance(arg.nctype.name, SpecialArgName)

0 commit comments

Comments
 (0)
0