improved serialization (no tar copy) #713
Conversation
Force-pushed from cdc97f8 to 7b4ec7f.
test/test_torch.py
Outdated
@@ -2497,6 +2500,7 @@ def test_serialization(self):
        b = [a[i % 2] for i in range(4)]
        b += [a[0].storage()]
        b += [a[0].storage()[1:4]]
        b += [torch.range(1, 10).int()]
@@ -190,13 +190,33 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file)
   int fd = PyObject_AsFileDescriptor(file);
   THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
       "descriptor from given object");
-  THStoragePtr storage = THPStorage_(readFileRaw)(fd);
+  THStoragePtr storage = THPStorage_(readFileRaw)(fd, nullptr);
   PyObject *result = THPStorage_(New)(storage);
   storage.release();
PyObject *file = PyTuple_GET_ITEM(args, 0);
int fd = PyObject_AsFileDescriptor(file);

PyObject *offset = PyTuple_GET_ITEM(args, 1);
torch/serialization.py
Outdated
serialized_storages[root_key] = root
is_view = obj._cdata != root._cdata

return ('storage', storage_type, root_key, location, root.size(), is_view, offset, obj.size())
torch/serialization.py
Outdated
else:
    raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

# try the legacy loader first, which only works if f is a tarfile
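The "try the legacy loader first" comment can be sketched as a try/fall-back pattern. This is an illustrative stand-in, not the real torch code: legacy_load and new_load here are hypothetical callables, and the real loader does more careful format detection.

```python
import io
import tarfile

def load_any(f, legacy_load, new_load):
    """Try the legacy tar-based loader first; if the file is not a
    tarfile, rewind and hand it to the new-format loader."""
    offset = f.tell()
    try:
        return legacy_load(f)
    except tarfile.TarError:
        # The tar probe may have consumed bytes; seek back before retrying.
        f.seek(offset)
        return new_load(f)

def legacy_load(f):
    # Raises for anything that is not actually a tar archive.
    raise tarfile.TarError("not a tarfile")

def new_load(f):
    return ("new-format", f.read())

result = load_any(io.BytesIO(b"PK"), legacy_load, new_load)
```

The important detail (also discussed in the notes at the end of the PR) is the explicit seek: the failed probe reads ahead, so the fallback must rewind to the saved offset before re-parsing.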
if magic_number != MAGIC_NUMBER:
    raise RuntimeError("Invalid magic number; corrupt file?")
protocol_version = pickle_module.load(f)
if protocol_version != PROTOCOL_VERSION:
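The header check quoted above can be sketched end-to-end with plain pickle. The constant values below mirror what torch/serialization.py used around this time, but treat them as illustrative; write_header/check_header are simplified stand-ins for the real save/load paths.

```python
import io
import pickle

# Illustrative constants standing in for torch.serialization's header values.
MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
PROTOCOL_VERSION = 1001

def write_header(f):
    # The new format starts with two pickled values: a magic number
    # identifying the file type and the serialization protocol version.
    pickle.dump(MAGIC_NUMBER, f)
    pickle.dump(PROTOCOL_VERSION, f)

def check_header(f):
    magic_number = pickle.load(f)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle.load(f)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

buf = io.BytesIO()
write_header(buf)
buf.seek(0)
check_header(buf)  # passes silently for a well-formed header
```

Reading the header as two separate pickle loads means a truncated or foreign file fails fast, before any storage data is interpreted.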
torch/serialization.py
Outdated
    data_type(size), location)
data = deserialized_storages[key]
if is_view:
    data = data[offset:offset + view_size]
torch/serialization.py
Outdated
if key not in deserialized_storages:
    deserialized_storages[key] = restore_location(
        data_type(size), location)
data = deserialized_storages[key]
def persistent_id(obj):
    # FIXME: the docs say that persistent_id should only return a string
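The FIXME refers to a real quirk of the pickle module: the documentation says persistent_id should return a string, but the binary protocols happily accept any picklable object, such as the tuples this PR returns. A self-contained sketch of that pairing (FakeStorage and the registry are illustrative stand-ins, not torch classes):

```python
import io
import pickle

storages = {}  # out-of-band registry, keyed by a string id

class FakeStorage:
    def __init__(self, data):
        self.data = data

class StoragePickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Intercept storages and replace them with an out-of-band id.
        if isinstance(obj, FakeStorage):
            key = str(id(obj))
            storages[key] = obj.data
            # A tuple, not a string: works for the binary protocols,
            # but not the text protocol (protocol 0).
            return ('storage', key, len(obj.data))
        return None  # everything else pickles normally

class StorageUnpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        typename, key, size = saved_id
        if typename != 'storage':
            raise RuntimeError("Unknown saved id type: %s" % typename)
        return FakeStorage(storages[key])

buf = io.BytesIO()
StoragePickler(buf, protocol=2).dump([FakeStorage([1, 2, 3])])
buf.seek(0)
result = StorageUnpickler(buf).load()
result[0].data  # [1, 2, 3]
```

Because the id round-trips through persistent_load, the storage bytes never pass through pickle itself, which is what allows the raw file read/write paths in serialization.cpp to handle them.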
Force-pushed from 7b4ec7f to 82576ba.
b += [torch.range(1, 10).int()]
t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
b += [(t1.storage(), t1.storage(), t2.storage())]
views = c[7]
self.assertEqual(views[0]._cdata, views[1]._cdata)
self.assertEqual(views[0], views[2])
self.assertNotEqual(views[0]._cdata, views[2]._cdata)
torch/serialization.py
Outdated
    _check_container_source(*data)
    return data[0]
elif typename == 'storage':
    data_type, key, location, size, is_view, view_metadata = data
torch/serialization.py
Outdated
deserialized_objects[key] = restore_location(
    data_type(size), location)
storage = deserialized_objects[key]
if is_view:
torch/tensor.py
Outdated
        tuple(self.size()),
        self.stride()))

def __setstate__(self, state):
Force-pushed from f63675b to 8b22927.
@soumith please don't merge yet.
Force-pushed from ef93662 to 18c778a.
Okay, I think it's ready for merge now. P.S. as of this change, .pt files are no longer compressed; I've seen this increase their size by 2x in cases where the data is compressible (e.g. LongTensor with small ints).
Huh? I don't think they're compressed now, so the file sizes should be quite similar.
Force-pushed from 18c778a to 81883fc.
torch/csrc/generic/serialization.cpp
Outdated
  } else {
-   int64_t buffer_size = std::min(self->size, (long)5000);
+   long buffer_size = std::min(size, (long)5000);
Force-pushed from 81883fc to a1c2819.
while (remaining > 0) {
  ssize_t result = write(fd, bytes, remaining);
  if (result < 0)
    throw std::system_error(result, std::system_category());
  bytes += result;
  remaining -= result;
}
if (remaining != 0)
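The loop above handles the fact that POSIX write() may perform a short write. The same pattern, translated to Python as a sketch (write_all is a hypothetical helper; os.write raises OSError itself, so the explicit result < 0 check from the C++ disappears):

```python
import os

def write_all(fd, data):
    """Keep calling os.write until every byte of data is on the fd,
    tolerating short writes."""
    view = memoryview(data)
    while view:  # an empty memoryview is falsy
        written = os.write(fd, view)  # may write fewer bytes than asked
        view = view[written:]         # advance past what was written

# Usage: round-trip a buffer through a pipe.
r, w = os.pipe()
write_all(w, b"hello world")
os.close(w)
echoed = os.read(r, 1024)
os.close(r)
```

Advancing a memoryview instead of slicing bytes avoids copying the unsent tail on each iteration, which matters for the multi-gigabyte storages this PR is optimizing for.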
torch/csrc/generic/serialization.cpp
Outdated
} else { | ||
int64_t buffer_size = std::min(self->size, (long)5000); | ||
int64_t buffer_size = std::min(size, (long)5000); |
Some notes:

- This runs 5-10x faster than the old code on my machine; it reads big tensors at ~2.5GB/s, which is faster than the file reading unless the file is cached.
- I tried to make this avoid seeks, but load takes arbitrary file objects and they may be buffered in Python, so when you start reading the storages in C, you need to do a C seek back to before the extra Python buffering. If we want to avoid seeks, the only choice is to do the file reading purely in C. I will leave this to a future PR if it's necessary.
- I changed tensor.__reduce__ to return (t.storage(), t.storageOffset, t.size(), t.stride) instead of a nested list representation of the tensor. I think this is more consistent with the fact that a tensor is a view on a storage (otherwise __reduce__ would break sharing, etc.). One upshot is that it changes the behavior of copy.copy, which is now just a shallow copy of the tensor view.
- persistent_id is technically only allowed to return a string, but we return tuples. This happens to work for the binary pickle protocols, but not the text protocol, and it's not guaranteed to work in the future. I'm not sure how to fix it for the module case.
- load remains backwards compatible with old checkpoints (although they'll still be slow). There's a test for it.
- The old serialization code had a little bug I noticed: it only treated a storage as a view if offset != 0. But technically you could have a view with offset == 0, e.g. s2 = s[0:5].
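A minimal model of the new __reduce__ contract described in the notes, showing why copy.copy becomes a shallow copy of the view. MiniTensor and _rebuild are illustrative stand-ins, not the real torch classes:

```python
import copy
import pickle

def _rebuild(storage, offset, size, stride):
    # Reconstruct a tensor as a view on the given storage.
    return MiniTensor(storage, offset, size, stride)

class MiniTensor:
    def __init__(self, storage, offset, size, stride):
        self.storage, self.offset = storage, offset
        self.size, self.stride = size, stride

    def __reduce__(self):
        # Pickle the view parameters, not a nested-list snapshot of
        # the data, so tensors sharing a storage keep sharing it.
        return (_rebuild, (self.storage, self.offset, self.size, self.stride))

s = [1.0, 2.0, 3.0, 4.0]          # stand-in for a shared storage
t = MiniTensor(s, 1, (3,), (1,))

# copy.copy goes through __reduce_ex__, so it rebuilds the view around
# the *same* storage object rather than duplicating the data:
t2 = copy.copy(t)
t2.storage is t.storage  # True

# A full pickle round-trip serializes the storage by value instead:
t3 = pickle.loads(pickle.dumps(t))
```

In the real PR the storage component is additionally intercepted by persistent_id, so even the pickle round-trip writes each root storage only once.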