8000 [dynamo] pack resume function stack + locals into a list · pytorch/pytorch@f2591ce · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit f2591ce

Browse files
committed
[dynamo] pack resume function stack + locals into a list
ghstack-source-id: b9185f0 Pull Request resolved: #151056
1 parent de7fb5f commit f2591ce

File tree

2 files changed

+28
-82
lines changed

2 files changed

+28
-82
lines changed

torch/_dynamo/resume_execution.py

Lines changed: 20 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,12 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]):
345345
code_options["co_firstlineno"] = lineno
346346
code_options["co_cellvars"] = ()
347347
code_options["co_freevars"] = freevars
348-
code_options["co_argcount"] = len(args)
348+
code_options["co_argcount"] = 1
349349
code_options["co_posonlyargcount"] = 0
350350
code_options["co_kwonlyargcount"] = 0
351351
code_options["co_varnames"] = tuple(
352-
args
352+
["__frame_values"]
353+
+ args
353354
+ [v for v in argnames_null if v not in args]
354355
+ [
355356
v
@@ -370,6 +371,23 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]):
370371
)
371372
prefix.append(create_instruction("RESUME", arg=0))
372373

374+
# load this frame's values (stack + locals) to the right places
375+
load_frame_values_source = "def load_frame_values(__frame_values):\n"
376+
if len(args) == 0:
377+
# load the value anyway so symbolic_convert can pop this local internally
378+
load_frame_values_source += " __frame_values"
379+
elif len(args) == 1:
380+
load_frame_values_source += f" {args[0]} = __frame_values[0]"
381+
else:
382+
load_frame_values_source += f" {', '.join(args)} = __frame_values"
383+
384+
load_frame_values_locals: dict[str, Any] = {}
385+
exec(load_frame_values_source, None, load_frame_values_locals)
386+
load_frame_values_bytecode = bytecode_from_template(
387+
load_frame_values_locals["load_frame_values"],
388+
)
389+
prefix.extend(load_frame_values_bytecode)
390+
373391
cleanup: list[Instruction] = []
374392
hooks = {fn.stack_index: fn for fn in setup_fns}
375393
hook_target_offsets = {
@@ -567,78 +585,3 @@ def remap_block_offsets(
567585
return ContinueExecutionCache.lookup(
568586
meta.code, lineno, new_offset, setup_fn_target_offsets, *args
569587
)
570-
571-
572-
"""
573-
# partially finished support for with statements
574-
575-
def convert_locals_to_cells(
576-
instructions: List[Instruction],
577-
code_options: Dict[str, Any]):
578-
579-
code_options["co_cellvars"] = tuple(
580-
var
581-
for var in code_options["co_varnames"]
582-
if var not in code_options["co_freevars"]
583-
and not var.startswith("___stack")
584-
)
585-
cell_and_free = code_options["co_cellvars"] + code_options["co_freevars"]
586-
for inst in instructions:
587-
if str(inst.argval).startswith("___stack"):
588-
continue
589-
elif inst.opname == "LOAD_FAST":
590-
inst.opname = "LOAD_DEREF"
591-
elif inst.opname == "STORE_FAST":
592-
inst.opname = "STORE_DEREF"
593-
elif inst.opname == "DELETE_FAST":
594-
inst.opname = "DELETE_DEREF"
595-
else:
596-
continue
597-
inst.opcode = dis.opmap[inst.opname]
598-
assert inst.argval in cell_and_free, inst.argval
599-
inst.arg = cell_and_free.index(inst.argval)
600-
601-
def patch_setup_with(
602-
instructions: List[Instruction],
603-
code_options: Dict[str, Any]
604-
):
605-
nonlocal need_skip
606-
need_skip = True
607-
target_index = next(
608-
idx for idx, i in enumerate(instructions) if i.offset == offset
609-
)
610-
assert instructions[target_index].opname == "SETUP_WITH"
611-
convert_locals_to_cells(instructions, code_options)
612-
613-
stack_depth_before = nstack + stack_effect(instructions[target_index].opcode,
614-
instructions[target_index].arg)
615-
616-
inside_with = []
617-
inside_with_resume_at = None
618-
stack_depth = stack_depth_before
619-
idx = target_index + 1
620-
for idx in range(idx, len(instructions)):
621-
inst = instructions[idx]
622-
if inst.opname == "BEGIN_FINALLY":
623-
inside_with_resume_at = inst
624-
break
625-
elif inst.target is not None:
626-
unimplemented("jump from with not supported")
627-
elif inst.opname in ("BEGIN_FINALLY", "WITH_CLEANUP_START", "WITH_CLEANUP_FINISH", "END_FINALLY",
628-
"POP_FINALLY", "POP_EXCEPT",
629-
"POP_BLOCK", "END_ASYNC_FOR"):
630-
unimplemented("block ops not supported")
631-
inside_with.append(inst)
632-
stack_depth += stack_effect(inst.opcode, inst.arg)
633-
assert inside_with_resume_at
634-
635-
instructions = [
636-
create_instruction("LOAD_FAST", f"___stack{i}") for i in range(nstack)
637-
] + [
638-
create_instruction("SETUP_WITH", target=instructions[target_index].target)
639-
... call the function ...
640-
unpack_tuple
641-
] + [
642-
create_instruction("JUMP_ABSOLUTE", target=inside_with_resume_at)
643-
]
644-
"""

torch/_dynamo/symbolic_convert.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,7 @@ def LOAD_FAST(self, inst):
14241424
)
14251425

14261426
# for continuation functions
1427-
if name.startswith("___stack"):
1427+
if name == "__frame_values" or name.startswith("__stack"):
14281428
self.symbolic_locals.pop(name)
14291429

14301430
def LOAD_DEREF(self, inst):
@@ -3633,19 +3633,22 @@ def create_call_resume_at(self, inst):
36333633
orig_graphmodule_maybe
36343634
)
36353635

3636+
# load stack and values into a single list
3637+
cg.extend_output([cg.create_load(k) for k in argnames])
3638+
cg.append_output(create_instruction("BUILD_LIST", arg=nargs))
3639+
36363640
if new_code.co_freevars:
36373641
# expose code object for debugging purposes
36383642
self.output.install_global_unsafe(name, new_code)
3639-
cg.make_function_with_closure(name, new_code, True, stack_len)
3643+
cg.make_function_with_closure(name, new_code, True, 1)
36403644
else:
36413645
# This is safe: we pre-generate a unique name
36423646
self.output.install_global_unsafe(
36433647
name, types.FunctionType(new_code, self.f_globals, name)
36443648
)
3645-
cg.extend_output(cg.load_function_name(name, True, stack_len))
3649+
cg.extend_output(cg.load_function_name(name, True, 1))
36463650

3647-
cg.extend_output([cg.create_load(k) for k in argnames])
3648-
cg.extend_output(create_call_function(nargs, False))
3651+
cg.extend_output(create_call_function(1, False))
36493652
cg.append_output(create_instruction("RETURN_VALUE"))
36503653
return cg.get_instructions()
36513654

0 commit comments

Comments
 (0)
0