Loss parallel's override of log_softmax doesn't support negative dims #152016
Labels
oncall: distributed
Add this issue/PR to distributed oncall triage queue
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
It is allowed to invoke `F.log_softmax` with a negative dimension, such as -1. However, when the loss parallel context manager is enabled, the log-softmax op gets overridden with a custom implementation that seems to require the dim to be non-negative.

pytorch/torch/distributed/tensor/parallel/loss.py
Line 473 in b32b002
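Below is a minimal sketch of how one might hit this, assuming a two-rank launch (e.g. `torchrun --nproc-per-node=2`) and the public DTensor APIs; the exact import paths can differ slightly across PyTorch versions, so treat this as an illustration rather than a verified reproducer.

```python
# Hypothetical reproduction sketch: logits sharded on the class dimension,
# then F.log_softmax called with dim=-1 inside loss_parallel().
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.parallel import loss_parallel

mesh = init_device_mesh("cuda", (2,))

# Shard the class dimension across the mesh, as loss parallel expects.
logits = distribute_tensor(torch.randn(8, 16), mesh, [Shard(1)])

with loss_parallel():
    # dim=-1 is valid for the regular op, but the loss-parallel override
    # appears to assume a non-negative dim and fails here.
    out = F.log_softmax(logits, dim=-1)
```

A plausible fix would be to normalize negative dims in the override (e.g. `dim = dim + x.ndim if dim < 0 else dim`) before the sharding check, though the maintainers may prefer a different approach.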
Versions
N/A
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k