I've had a number of discussions with folks about the idea of using TorchDynamo to capture data-dependent control flow into some sort of (yet to be created) control flow operator in PyTorch/aten.
Technically, capturing the control flow wouldn't be that hard. TorchDynamo sees the `POP_JUMP_IF_FALSE <tensor>` bytecode, and decides to break the graph on it. It could instead decide to inject an `aten::if` or `aten::bailout_if` into the graph.
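For illustration, here is a minimal sketch of where that conditional jump shows up in the bytecode, using only the standard `dis` module. The function `f` is a made-up example and is never called, so no tensor is needed. The exact opcode name varies across CPython versions (for example `POP_JUMP_IF_FALSE` or `POP_JUMP_FORWARD_IF_FALSE`), so the filter matches on the common prefix.

```python
import dis


def f(x):
    # x is assumed to be tensor-like; f is only disassembled here, never called.
    if x.sum() > 0:
        return x + 1
    return x - 1


# Collect the conditional-jump instructions the compiler emitted for the `if`.
jump_ops = [
    ins.opname
    for ins in dis.get_instructions(f)
    if ins.opname.startswith("POP_JUMP")
]
print(jump_ops)
```

This is the instruction whose operand would be a tensor at run time, which is the point where a graph break happens today.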
The real challenge is how to trace the remainder of the function. The easiest approach would be to follow a tracing JIT pattern and only follow the branch taken with the initial example inputs that TorchDynamo already tracks. This would result in an `aten::bailout_if` op that would allow early exit from the graph if execution diverges from the original trace.
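A toy sketch of the bailout idea in plain Python, not tied to any real TorchDynamo or aten API. All names here (`Bailout`, `trace_branch`, `run`, `eager`) are hypothetical, and plain numbers stand in for tensors. Only the branch taken by the example input is recorded, and a guard plays the role the proposed `aten::bailout_if` op would have.

```python
class Bailout(Exception):
    """Signals that a run diverged from the branch recorded at trace time."""


def trace_branch(pred, true_fn, false_fn, example):
    # Record only the branch that the example input takes.
    taken = bool(pred(example))
    body = true_fn if taken else false_fn

    def traced(x):
        # Guard standing in for the hypothetical bailout op:
        # exit early if this input would take the other branch.
        if bool(pred(x)) != taken:
            raise Bailout()
        return body(x)

    return traced


def run(traced, eager, x):
    # Use the traced path when the guard holds, else fall back to eager code.
    try:
        return traced(x)
    except Bailout:
        return eager(x)


def eager(x):
    # The original, untraced function.
    return x + 1 if x > 0 else x - 1


traced = trace_branch(lambda x: x > 0, lambda x: x + 1, lambda x: x - 1, example=3)
print(run(traced, eager, 5))   # guard holds, traced branch runs
print(run(traced, eager, -5))  # guard fails, falls back to eager
```

The trade-off in this design is that the traced path stays a straight line, at the cost of falling back whenever an input takes the branch that was not recorded.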
Other options that capture multiple sides of the branch would also be possible, but more complex to implement.
I'd be curious how well backends would do with these types of operators. Currently, I don't think we do fusions that cross control flow boundaries, so the speedups from this could be limited.
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @soumith @wconstab @ngimel @mlazos @yanboliang @Xia-Weiwen @desertfire