The current Module.to implementation keeps the moved tensors in the same autograd graph as the original ones:
case Parameter param: {
    var t = param.to(dtype, device);  // in the same graph
    t.retain_grad();
    ...
}
case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
    var t = tensor.to(dtype, device);  // in the same graph
    ...
}
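To make "in the same graph" concrete, here is a minimal standalone sketch in PyTorch terms (the tensor names are illustrative, not from the library): a plain .to() on a parameter that requires grad yields a non-leaf copy whose grad_fn points back at the original, which is effectively what the code above produces.

import torch

# A plain .to() outside of no_grad keeps the copy attached to the original's
# autograd graph: the result is a non-leaf tensor with a grad_fn.
param = torch.nn.Parameter(torch.randn(3, 3))   # leaf tensor
moved = param.to(torch.float64)                 # stays in the same graph
print(moved.is_leaf)                            # False
print(moved.grad_fn is not None)                # True
# Because 'moved' is not a leaf, retain_grad() is needed to get a .grad on it,
# mirroring the t.retain_grad() call in the C# snippet above.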
In PyTorch's own nn.Module implementation, by contrast, the moved tensors are detached:
for key, param in self._parameters.items():
    if param is None:
        continue
    # Tensors stored in modules are graph leaves, and we don't want to
    # track autograd history of `param_applied`, so we have to use
    # `with torch.no_grad():`
    with torch.no_grad():  # detached from the original tensor
        param_applied = fn(param)
    ...
    if param.grad is not None:
        with torch.no_grad():  # detached from the original tensor
            grad_applied = fn(param.grad)
        ...
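The effect of the no_grad guard can be verified in isolation (a small sketch, not the library code itself): the copy comes out as a leaf tensor with no grad_fn, so later backward passes cannot reach the original.

import torch

# Under no_grad the copy is detached: it is a leaf tensor with no grad_fn,
# so nothing in a later backward pass can flow back to the original tensor.
param = torch.nn.Parameter(torch.randn(3, 3))
with torch.no_grad():
    param_applied = param.to(torch.float64)
print(param_applied.is_leaf)         # True
print(param_applied.grad_fn is None) # True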
Since module tensors should be leaf tensors, it makes no sense to keep them attached to the original tensors. If a CUDA module is not detached from the original CPU tensors, there is unnecessary gradient traffic from the GPU to the CPU in every backpropagation.
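A sketch of that cost, assuming a CUDA device is available (the names are illustrative): when the GPU copy is left attached to the CPU parameter, every backward pass computes the gradient on the GPU and then copies it back to the CPU original.

import torch

if torch.cuda.is_available():
    cpu_param = torch.nn.Parameter(torch.randn(1024, 1024))  # original CPU leaf
    gpu_weight = cpu_param.to("cuda")                         # not detached
    loss = (gpu_weight * 2).sum()
    loss.backward()
    # The gradient is computed on the GPU and then copied back to the CPU
    # original on every backward pass, which is exactly the unwanted traffic.
    print(cpu_param.grad is not None)   # True
    print(cpu_param.grad.device)        # cpu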