File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -209,6 +209,22 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
209
209
210
210
Args:
211
211
safe_globals (List[Any]): list of globals to mark as safe
212
+
213
+ Example:
214
+ >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
215
+ >>> import tempfile
216
+ >>> class MyTensor(torch.Tensor):
217
+ ... pass
218
+ >>> t = MyTensor(torch.randn(2, 3))
219
+ >>> with tempfile.NamedTemporaryFile() as f:
220
+ ... torch.save(t, f.name)
221
+ # Running `torch.load(f.name, weights_only=True)` will fail with
222
+ # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
223
+ # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
224
+ ... torch.serialization.add_safe_globals([MyTensor])
225
+ ... torch.load(f.name, weights_only=True)
226
+ # MyTensor([[-0.5024, -1.8152, -0.5455],
227
+ # [-0.8234, 2.0500, -0.3657]])
212
228
"""
213
229
_weights_only_unpickler ._add_safe_globals (safe_globals )
214
230
You can’t perform that action at this time.
0 commit comments