Module.to should not be differentiable · Issue #1148 · dotnet/TorchSharp

Description

@zgxnet

The current Module.to implementation keeps the moved tensors in the same autograd graph as the originals.

    case Parameter param: {
            var t = param.to(dtype, device); // in the same graph
            t.retain_grad();
            ...
        }

    case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
            var t = tensor.to(dtype, device); // in the same graph
            ...
        }
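
As a hedged illustration of the symptom (not code from the issue), the sketch below assumes a CUDA device is available and that TorchSharp exposes Tensor.is_leaf as PyTorch does; after the move, the converted weight would still be attached to the original CPU parameter's graph.

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    // Hypothetical repro sketch: keep a handle to the original CPU parameter,
    // then move the module to CUDA with the current Module.to.
    var lin = nn.Linear(4, 2);
    var cpuWeight = lin.weight;   // original CPU leaf parameter

    lin.to(CUDA);

    // Because the conversion goes through a differentiable `.to` call, the
    // converted weight is expected not to be a leaf, and a later backward()
    // would also push gradients back to cpuWeight on the CPU.
    Console.WriteLine(lin.weight!.is_leaf);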

In the PyTorch implementation, by contrast, the moved tensors are detached from the originals:

        for key, param in self._parameters.items():
            if param is None:
                continue
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():  #detached from the original tensor
                param_applied = fn(param)
                ...
            if param.grad is not None:
                with torch.no_grad(): #detached from the original tensor
                    grad_applied = fn(param.grad)
                ...
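
A detached conversion along the same lines might look like the following TorchSharp sketch. MoveDetached is a hypothetical helper, not the actual Module.to code; it assumes torch.no_grad() (which returns an IDisposable in TorchSharp), Tensor.to(dtype, device), and the Parameter(Tensor, bool requires_grad) constructor behave like their PyTorch counterparts.

    using static TorchSharp.torch;

    static class DetachedMove
    {
        // Hypothetical helper, not the actual Module.to implementation.
        public static nn.Parameter MoveDetached(nn.Parameter param, ScalarType dtype, Device device)
        {
            // Run the conversion outside the autograd graph, as PyTorch does
            // with `with torch.no_grad():`.
            using var _ = no_grad();
            var t = param.to(dtype, device);
            // Wrap the detached result as a fresh leaf parameter, preserving
            // requires_grad (mirrors torch.nn.Parameter(param_applied, param.requires_grad)).
            return new nn.Parameter(t, param.requires_grad);
        }
    }

Applying the same pattern to the buffer (Tensor) case would likewise keep module tensors as graph leaves after a move.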

Since module tensors should be leaf tensors, there is no reason to keep the original tensors in the graph. If a CUDA module is not detached from the original CPU tensors, every backpropagation incurs unnecessary gradient traffic from the GPU to the CPU.
