[DTensor] Scalar multiplication after reduction doesn't update result without calling .full_tensor() before #153603
Labels
oncall: distributed
🐛 Describe the bug
When a DTensor undergoes a reduction (e.g., using .mean() or .sum()), and the resulting DTensor is then multiplied by a scalar, the calculation yields an incorrect result if .full_tensor() is not called on the reduced DTensor before the multiplication. It appears that the result of the first such multiplication is cached and incorrectly reused in subsequent multiplications with different scalars. This issue does not occur if the multiplier is a tensor instead of a scalar.
Versions
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k