Delete sections referencing torchscript in serialization docs by mikaylagawarecki · Pull Request #156648 · pytorch/pytorch · GitHub
Closed
166 changes: 0 additions & 166 deletions docs/source/notes/serialization.rst
@@ -339,172 +339,6 @@ if one does not have access to the ``torch.load`` callsites.
if ``weights_only`` was not passed as an argument.


.. _serializing-python-modules:

Serializing torch.nn.Modules and loading them in C++
----------------------------------------------------

See also: `Tutorial: Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_

ScriptModules can be serialized as a TorchScript program and loaded
using :func:`torch.jit.load`.
This serialization encodes all the modules’ methods, submodules, parameters,
and attributes, and it allows the serialized program to be loaded in C++
(i.e. without Python).

The distinction between :func:`torch.jit.save` and :func:`torch.save` may not
be immediately clear. :func:`torch.save` saves Python objects with pickle.
This is especially useful for prototyping, researching, and training.
:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format
that can be loaded in Python or C++. This is useful when saving and loading C++
modules or for running modules trained in Python with C++, a common practice
when deploying PyTorch models.
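The contrast can be sketched side by side. This is a minimal illustration, assuming a small example module (the class and file names here are arbitrary, not part of the original doc):

```python
import torch

class MyModule(torch.nn.Module):
    # A minimal example module (assumption, for illustration only).
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.l0(x)

m = MyModule()

# torch.save pickles Python objects; reloading a state_dict requires the
# Python class definition of MyModule to rebuild the module.
torch.save(m.state_dict(), 'mymodule_state.pt')
m2 = MyModule()
m2.load_state_dict(torch.load('mymodule_state.pt'))

# torch.jit.save serializes a compiled ScriptModule; reloading it does not
# need the original Python source, and the same archive is loadable from C++.
torch.jit.save(torch.jit.script(m), 'mymodule_scripted.pt')
loaded = torch.jit.load('mymodule_scripted.pt')
```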

To script, serialize and load a module in Python:

::

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule(
  original_name=MyModule
  (l0): RecursiveScriptModule(original_name=Linear)
  (l1): RecursiveScriptModule(original_name=Linear)
)


Traced modules can also be saved with :func:`torch.jit.save`, with the caveat
that only the traced code path is serialized. The following example demonstrates
this:

::

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)

def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)

out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule())
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>>> loaded(torch.randn(2, 4))
tensor(0)

The above module has an if statement that is not triggered by the example
inputs used for tracing, so it is not part of the traced module and is not
serialized with it. The scripted module, however, contains the if statement
and is serialized with it.
See the `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
for more on scripting and tracing.
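One way to see the difference directly is to print the ``code`` attribute of each compiled module, which shows the TorchScript source that will be serialized. This is a sketch (tracing a module with data-dependent control flow emits a ``TracerWarning``, which is itself a hint that the branch was dropped):

```python
import torch

class ControlFlowModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)
        out0 = self.l0(input)
        return self.l1(torch.nn.functional.relu(out0))

traced = torch.jit.trace(ControlFlowModule(), torch.randn(4))
scripted = torch.jit.script(ControlFlowModule())

# The traced graph records only the path taken by the example input, so
# the `if` branch does not appear in its code; the scripted code keeps it.
print(traced.code)
print(scripted.code)
```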

Finally, to load the module in C++:

::

>>> torch::jit::script::Module module;
>>> module = torch::jit::load("controlflowmodule_scripted.pt");

See the `PyTorch C++ API documentation <https://pytorch.org/cppdocs/>`_
for details about how to use PyTorch modules in C++.

.. _saving-loading-across-versions:

Saving and loading ScriptModules across PyTorch versions
-----------------------------------------------------------

The PyTorch Team recommends saving and loading modules with the same version of
PyTorch. Older versions of PyTorch may not support newer modules, and newer
versions may have removed or modified older behavior. These changes are
explicitly described in
PyTorch’s `release notes <https://github.com/pytorch/pytorch/releases>`_,
and modules relying on functionality that has changed may need to be updated
to continue working properly. In limited cases, detailed below, PyTorch will
preserve the historic behavior of serialized ScriptModules so they do not require
an update.
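When version mismatches are a concern, it can help to record the saving PyTorch version inside the archive itself, using the ``_extra_files`` argument of :func:`torch.jit.save` and :func:`torch.jit.load`. This is a sketch; the file key (``pytorch_version.txt``) and archive name are arbitrary choices, not a convention:

```python
import torch

m = torch.jit.script(torch.nn.Linear(2, 2))

# Record the version of PyTorch used at save time inside the archive.
extra = {'pytorch_version.txt': torch.__version__}
torch.jit.save(m, 'linear_scripted.pt', _extra_files=extra)

# At load time, request the file back and compare it against the running
# version to diagnose mismatches. Note: recent PyTorch releases return the
# extra-file contents as bytes rather than str.
extra_out = {'pytorch_version.txt': ''}
loaded = torch.jit.load('linear_scripted.pt', _extra_files=extra_out)
saved_version = extra_out['pytorch_version.txt']
```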

torch.div performing integer division
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when
given two integer inputs:

::

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

In PyTorch 1.7, however, :func:`torch.div` will always perform a true division
of its inputs, just like division in Python 3:

::

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

The behavior of :func:`torch.div` is preserved in serialized ScriptModules.
That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue
to see :func:`torch.div` perform floor division when given two integer inputs
even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div`
and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of
PyTorch, however, since those earlier versions do not understand the new behavior.
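Rather than relying on preserved ScriptModule semantics, code being ported forward can request the old behavior explicitly with the ``rounding_mode`` argument of :func:`torch.div`, available in PyTorch 1.8 and later. A minimal sketch:

```python
import torch

a = torch.tensor(5)
b = torch.tensor(3)

# True division, the modern default, like Python 3's `/`.
true_quotient = torch.div(a, b)                          # tensor(1.6667)

# Explicit floor division reproduces the pre-1.6 integer behavior.
floor_quotient = torch.div(a, b, rounding_mode='floor')  # tensor(1)
```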

torch.full always inferring a float dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor,
regardless of the fill value it’s given:

::

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!

In PyTorch 1.7, however, :func:`torch.full` will infer the returned tensor’s
dtype from the fill value:

::

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is,
ScriptModules serialized with versions of PyTorch before 1.6 will continue to see
torch.full return float tensors by default, even when given bool or
integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
and later cannot be loaded in earlier versions of PyTorch, however, since those
earlier versions do not understand the new behavior.
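Passing ``dtype`` explicitly sidesteps the inference change entirely: the call then behaves the same on every PyTorch version, old or new. A minimal sketch:

```python
import torch

# dtype inferred from the fill value (PyTorch 1.7+ behavior).
t_int = torch.full((3,), 1)                       # integer tensor

# An explicit dtype is version-independent: this returns a float tensor
# on any PyTorch release, regardless of the inference rules.
t_float = torch.full((3,), 1, dtype=torch.float)  # tensor([1., 1., 1.])
```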

.. _utility functions:

Utility functions