torch.compiler.disable() on module hooks will disable module.compile() #142358
Labels: module: dynamo, module: nn, oncall: pt2, triaged
🐛 Describe the bug
`torch.compiler.disable()` on module hooks disables compilation of the whole module when `module.compile()` is used (or when compiling via `torch.compile(some_function)(model, some_inputs)`).

To check whether the model is compiled, I use `TORCH_LOGS="output_code"`. For `model.compile()`, there is no code output. For `model.forward = torch.compile(model.forward)`, there is compiled Triton code as expected.

Although `model.forward = torch.compile(model.forward)` is a valid workaround, it prevents me from compiling a full function that takes the model as an argument, e.g. a loss computation with signature `compute_loss(model, inputs, labels)`. Here too, a workaround is to isolate the loss computation logic and compile it separately, but that gets messy. Ideally, `torch.compile` should skip only the hooks and still compile the `.forward()` logic.

Error logs
No response
Versions
torch==2.6.0.dev20241208+cu126
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames