8000 Add example for torch.serialization.add_safe_globals · pytorch/pytorch@eba6f42 · GitHub
[go: up one dir, main page]

Skip to content

Commit eba6f42

Browse files
Add example for torch.serialization.add_safe_globals
ghstack-source-id: e23d66f Pull Request resolved: #129396
1 parent ae20e92 commit eba6f42

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

torch/serialization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,22 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
209209
210210
Args:
211211
safe_globals (List[Any]): list of globals to mark as safe
212+
213+
Example:
214+
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
215+
>>> import tempfile
216+
>>> class MyTensor(torch.Tensor):
217+
... pass
218+
>>> t = MyTensor(torch.randn(2, 3))
219+
>>> with tempfile.NamedTemporaryFile() as f:
220+
... torch.save(t, f.name)
221+
# Running `torch.load(f.name, weights_only=True)` will fail with
222+
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
223+
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
224+
... torch.serialization.add_safe_globals([MyTensor])
225+
... torch.load(f.name, weights_only=True)
226+
# MyTensor([[-0.5024, -1.8152, -0.5455],
227+
# [-0.8234, 2.0500, -0.3657]])
212228
"""
213229
_weights_only_unpickler._add_safe_globals(safe_globals)
214230

0 commit comments

Comments
 (0)
0