[DTensor] Partial(sum) reductions are wrongly cached (?) #147180
Labels: module: dtensor, oncall: distributed, triaged
🐛 Describe the bug
First of all, a very simple motivating example:
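A minimal sketch of the kind of reproduction described (an assumption, not the original snippet; it presumes a 2-rank `torchrun` launch, the public `torch.distributed.tensor` API, and an illustrative tensor shape):

```python
# Hypothetical reproduction sketch -- not the original code from this report.
# Assumes: launched via `torchrun --nproc-per-node 2 repro.py`
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial

mesh = init_device_mesh("cpu", (2,))

# Each rank holds a *different* local randn tensor; Partial() means the
# logical tensor is the pending sum of the local tensors across ranks.
torch.manual_seed(dist.get_rank())
dt = DTensor.from_local(torch.randn(4), mesh, [Partial()])

a = (-5 * dt).full_tensor()  # should be -5 * (sum of locals)
b = (2 * dt).full_tensor()   # should be  2 * (sum of locals)

print(torch.equal(a, b))  # expected False; the report observes True
```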
In the above code, we construct a `Partial()` DTensor from different local randn tensors. `-5*dt` and `2*dt` are not the same when their work is replicated, yet `-5*dt` and `2*dt` return the same result (???) when the `Partial()` `dt` is used. If we print out the values involved, the issue becomes clearer:
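For example (illustrative prints; the concrete values from the report are not reproduced here):

```python
# Print the materialized values of the DTensor and both scaled results.
print("dt:    ", dt.full_tensor())
print("-5*dt: ", (-5 * dt).full_tensor())
print("2*dt:  ", (2 * dt).full_tensor())
# Per the report, the last two lines print identical values.
```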
Somehow, the result of `-5*dt` is cached and reused as the return value for `2*dt`... The following also returns true:
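A hypothetical version of such a check, consistent with the caching hypothesis (the exact comparison used in the report is unknown; `dt` is reused from above):

```python
# Hypothetical: if the first scalar-multiply's reduction result is cached,
# swapping the evaluation order should make both expressions agree as well.
b = (2 * dt).full_tensor()
a = (-5 * dt).full_tensor()
print(torch.equal(a, b))  # also True under the reported behavior
```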
I do not know how to debug what is happening any further.
Versions
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu