|
85 | 85 |
|
86 | 86 | IS_WINDOWS = sys.platform == "win32"
|
87 | 87 |
|
| 88 | +UNSAFE_MESSAGE = ( |
| 89 | + "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " |
| 90 | + "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " |
| 91 | + "but it can result in arbitrary code execution. Do it only if you got the file from a " |
| 92 | + "trusted source." |
| 93 | +) |
| 94 | + |
88 | 95 | if not IS_WINDOWS:
|
89 | 96 | from mmap import MAP_PRIVATE, MAP_SHARED
|
90 | 97 | else:
|
@@ -1341,12 +1348,6 @@ def load(
|
1341 | 1348 | >>> torch.load("module.pt", encoding="ascii", weights_only=False)
|
1342 | 1349 | """
|
1343 | 1350 | torch._C._log_api_usage_once("torch.load")
|
1344 |
| - UNSAFE_MESSAGE = ( |
1345 |
| - "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " |
1346 |
| - "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " |
1347 |
| - "but it can result in arbitrary code execution. Do it only if you got the file from a " |
1348 |
| - "trusted source." |
1349 |
| - ) |
1350 | 1351 | DOCS_MESSAGE = (
|
1351 | 1352 | "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
|
1352 | 1353 | "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
|
@@ -1611,6 +1612,11 @@ def persistent_load(saved_id):
|
1611 | 1612 | with closing(
|
1612 | 1613 | tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
|
1613 | 1614 | ) as tar, mkdtemp() as tmpdir:
|
| 1615 | + if pickle_module is _weights_only_unpickler: |
| 1616 | + raise RuntimeError( |
| 1617 | + "Cannot use ``weights_only=True`` with files saved in the " |
| 1618 | + "legacy .tar format. " + UNSAFE_MESSAGE |
| 1619 | + ) |
1614 | 1620 | tar.extract("storages", path=tmpdir)
|
1615 | 1621 | with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
|
1616 | 1622 | num_storages = pickle_module.load(f, **pickle_load_args)
|
|
0 commit comments