8000 improved serialization (no tar copy) by adamlerer · Pull Request #713 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

improved serialization (no tar copy) #713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 22, 2017
Merged

Conversation

adamlerer
Copy link
Contributor
@adamlerer adamlerer commented Feb 9, 2017

Some notes:

  • This runs 5-10x faster than the old code on my machine; it reads big tensors at ~2.5GB/s, which is faster than the file reading unless the file is cached.

  • I tried to make this avoid seeks, but load takes arbitrary file objects and they may be buffered in python, so when you start reading the storages in C, you need to do a C seek to before the extra Python buffering. If we want to avoid seeks, the only choice is to do the file reading purely in C. I will leave this to a future PR if it's necessary.

  • I change tensor.__reduce__ to return (t.storage(), t.storageOffset, t.size(), t.stride) instead of a nested list representation of the tensor. I think this is more consistent with the fact that a tensor is a view on storage (otherwise __reduce__ would break sharing, etc.). One upshot of this is it changes the behavior of copy.copy (which is now just a shallow copy of the tensor view).

  • persistent_id technically is only allowed to return a string, but we return tuples. This happens to work for the binary protocol, but not the string protocol. It's not guaranteed to work in the future. I'm not sure how to fix it for the module case

  • load remains backwards compatible with old checkpoints (although they'll still be slow). There's a test for it.

  • the old serialization code had a little bug I noticed, where it only treated a storage as a view if offset != 0. But technically, you could have a view with offset ==0, e.g. s2 = s[0:5].

@adamlerer adamlerer force-pushed the serialization branch 3 times, most recently from cdc97f8 to 7b4ec7f Compare February 10, 2017 21:57
@@ -2497,6 +2500,7 @@ def test_serialization(self):
b = [a[i % 2] for i in range(4)]
b += [a[0].storage()]
b += [a[0].storage()[1:4]]
b + [torch.range(1, 10).int()]

This comment was marked as off-topic.

@@ -190,13 +190,33 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file)
int fd = PyObject_AsFileDescriptor(file);
THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
"descriptor from given object");
THStoragePtr storage = THPStorage_(readFileRaw)(fd);
THStoragePtr storage = THPStorage_(readFileRaw)(fd, nullptr);
PyObject *result = THPStorage_(New)(storage);
storage.release();

This comment was marked as off-topic.

PyObject *file = PyTuple_GET_ITEM(args, 0);
int fd = PyObject_AsFileDescriptor(file);

PyObject *offset = PyTuple_GET_ITEM(args, 1);

This comment was marked as off-topic.

serialized_storages[root_key] = root
is_view = obj._cdata != root._cdata

return ('storage', storage_type, root_key, location, root.size(), is_view, offset, obj.size())

This comment was marked as off-topic.

This comment was marked as off-topic.

else:
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

# try the legacy loader first, which only works if f is a tarfile

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

if magic_number != MAGIC_NUMBER:
raise RuntimeError("Invalid magic number; corrupt file?")
protocol_version = pickle_module.load(f)
if protocol_version != PROTOCOL_VERSION:

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

data_type(size), location)
data = deserialized_storages[key]
if is_view:
data = data[offset:offset + view_size]

This comment was marked as off-topic.

This comment was marked as off-topic.

if key not in deserialized_storages:
deserialized_storages[key] = restore_location(
data_type(size), location)
data = deserialized_storages[key]

This comment was marked as off-topic.


def persistent_id(obj):
# FIXME: the docs say that persistent_id should only return a string

This comment was marked as off-topic.

if magic_number != MAGIC_NUMBER:
raise RuntimeError("Invalid magic number; corrupt file?")
protocol_version = pickle_module.load(f)
if protocol_version != PROTOCOL_VERSION:

This comment was marked as off-topic.

b += [torch.range(1, 10).int()]
t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
b += [(t1.storage(), t1.storage(), t2.storage())]

This comment was marked as off-topic.

views = c[7]
self.assertEqual(views[0]._cdata, views[1]._cdata)
self.assertEqual(views[0], views[2])
self.assertNotEqual(views[0]._cdata, views[2]._cdata)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

8000 torch/serialization.py Outdated
_check_container_source(*data)
return data[0]
elif typename == 'storage':
data_type, key, location, size, is_view, view_metadata = data

This comment was marked as off-topic.

deserialized_objects[key] = restore_location(
data_type(size), location)
storage = deserialized_objects[key]
if is_view:

This comment was marked as off-topic.

torch/tensor.py Outdated
tuple(self.size()),
self.stride()))

def __setstate__(self, state):

This comment was marked as off-topic.

This comment was marked as off-topic.

@adamlerer adamlerer force-pushed the serialization branch 2 times, most recently from f63675b to 8b22927 Compare February 13, 2017 06:36
@soumith soumith closed this Feb 15, 2017
@soumith soumith reopened this Feb 15, 2017
@adamlerer
Copy link
Contributor Author

@soumith please don't merge yet.

@adamlerer adamlerer force-pushed the serialization branch 3 times, most recently from ef93662 to 18c778a Compare February 15, 2017 19:20
@adamlerer
Copy link
Contributor Author

Okay, I think it's ready for merge now.

P.S. as of this change, .pt files are no longer compressed, I've seen this increase their size by 2x in cases where the data is compressible (e.g. LongTensor with small ints).

@apaszke
Copy link
Contributor
apaszke commented Feb 15, 2017

Huh? I don't think they're compressed now, so the file sizes should be quite similar.

@adamlerer
Copy link
Contributor Author

@apaszke you're right, my old checkpoints were 1/2 the size because of #717 , not because of the checkpoints. Thanks.

All ready for merge.

} else {
int64_t buffer_size = std::min(self->size, (long)5000);
long buffer_size = std::min(size, (long)5000);

This comment was marked as off-topic.

views = c[7]
self.assertEqual(views[0]._cdata, views[1]._cdata)
self.assertEqual(views[0], views[2])
self.assertNotEqual(views[0]._cdata, views[2]._cdata)

This comment was marked as off-topic.

while (remaining > 0) {
ssize_t result = write(fd, bytes, remaining);
if (result < 0)
throw std::system_error(result, std::system_category());
bytes += result;
remaining -= result;
}
if (remaining != 0)

This comment was marked as off-topic.

} else {
int64_t buffer_size = std::min(self->size, (long)5000);
int64_t buffer_size = std::min(size, (long)5000);

This comment was marked as off-topic.

throw std::system_error(result, std::system_category());
bytes += result;
remaining -= result;
}
if (remaining != 0)

This comment was marked as off-topic.

colesbury pushed a commit to colesbury/pytorch that referenced this pull request Feb 28, 2017
fix indexing bug in sampleMultinomialOnce
bddppq pushed a commit to bddppq/pytorch that referenced this pull request Apr 17, 2018
…9c90c8

Previous import was a4dcc47791eb127652f5aaddd51d8896d446a067

Included changes:
- **[985af3f](onnx/onnx@985af3f)**: Update PythonAPIOverview.md (pytorch#738) <Dmytro Dzhulgakov>
- **[b69be33](onnx/onnx@b69be33)**: Add backend test for upsample (pytorch#729) <Sebastian Meßmer>
- **[0d9496e](onnx/onnx@0d9496e)**: Input test data of concat op should be float (pytorch#711) <Changming Sun>
- **[20bcb8b](onnx/onnx@20bcb8b)**: Fix the spec for batchnorm and instancenorm (pytorch#733) <Lu Fang>
- **[c9f825f](onnx/onnx@c9f825f)**: Refine a little bit about op spec. (pytorch#666) <Ke Zhang>
- **[a484eb2](onnx/onnx@a484eb2)**: Fix an error in Conv doc (pytorch#731) <Lu Fang>
- **[7410cc4](onnx/onnx@7410cc4)**: Fix incorrect package output paths (pytorch#730) <bddppq>
- **[be546e2](onnx/onnx@be546e2)**: Improve optimizer's API and docs (pytorch#713) <Lu Fang>
- **[c61506f](onnx/onnx@c61506f)**: Fix the shape inference python API (pytorch#716) <Lu Fang>
- **[e9d4134](onnx/onnx@e9d4134)**: Fix cmake on windows when not building python extension (pytorch#728) <bddppq>
- **[72187aa](onnx/onnx@72187aa)**: Add value_info support in make_graph (pytorch#726) <Lu Fang>
- **[67b7d89](onnx/onnx@67b7d89)**: Fix gen_proto in cmake (pytorch#719) <bddppq>
- **[fcb4ae3](onnx/onnx@fcb4ae3)**: docs rewording: Important Python Functions -> Python API Overview (pytorch#721) <anderspapitto>
- **[24275d6](onnx/onnx@24275d6)**: Ignore .eggs directory when doing lint (pytorch#722) <bddppq>
- **[54be8fa](onnx/onnx@54be8fa)**: Use cmake3 if it's available (pytorch#718) <bddppq>
- **[b8c4238](onnx/onnx@b8c4238)**: Add python function docs (pytorch#714) <Lu Fang>
- **[e177493](onnx/onnx@e177493)**: Remove unused cmake utils (pytorch#712) <bddppq>
- **[72d6ad6](onnx/onnx@72d6ad6)**: Remove pycmd from CMake (pytorch#710) <bddppq>
- **[93f0d40](onnx/onnx@93f0d40)**: Fix windows local build (pytorch#709) <Raymond Yang>
- **[6734224](onnx/onnx@6734224)**: CMake fixes and setup.py cleanup (pytorch#706) <bddppq>
- **[7f6a4fd](onnx/onnx@7f6a4fd)**: Add docs to explain important functions in ONNX Infra (pytorch#682) <Lu Fang>
- **[f0f6b3d](onnx/onnx@f0f6b3d)**: fix hardmax test cases make output dtype same as input (pytorch#705) <Wenhao Hu>
- **[c970f0c](onnx/onnx@c970f0c)**: Fix the Dummy backend (pytorch#701) <Lu Fang>
- **[2af45df](onnx/onnx@2af45df)**: setup.py uses cmake build system (pytorch#606) <anderspapitto>
- **[dfcaade](onnx/onnx@dfcaade)**: clean up unused variable left by removing consumed_input (pytorch#697) <bddppq>
- **[accfc74](onnx/onnx@accfc74)**: Remove incorrect backend test (pytorch#700) <Lu Fang>
- **[e558732](onnx/onnx@e558732)**: add max inclusive version to defs.get_schema function (pytorch#695) <Wenhao Hu>
- **[16f02eb](onnx/onnx@16f02eb)**: add API to add domain to min/max version for extension. (pytorch#694) <Ke Zhang>
- **[3e560dd](onnx/onnx@3e560dd)**: Fix doc for initializer (pytorch#690) <bddppq>
- **[6cc4f53](onnx/onnx@6cc4f53)**: Add model save function (pytorch#692) <Lu Fang>
- **[21eaf9b](onnx/onnx@21eaf9b)**: Changing the string discussing versions in operator specifications. (pytorch#691) <Niklas Gustafsson>
- **[3b0cdf4](onnx/onnx@3b0cdf4)**: Minor code quality improvements in optimizer/ (pytorch#612) <Sebastian Meßmer>
- **[641f126](onnx/onnx@641f126)**: Fix Gemm doc wording (pytorch#689) <bddppq>
- **[4a0ec75](onnx/onnx@4a0ec75)**: Clarifies installation error message when external protobuf dependencies are missing (pytorch#684) <Daniel J. H>
- **[960a2c3](onnx/onnx@960a2c3)**: Check outputs dtype in backend tests (pytorch#567) <bddppq>
- **[1d7dee4](onnx/onnx@1d7dee4)**: Fix Average pool test cases converted from PyTorch (pytorch#677) <Lu Fang>
- **[36d7fff](onnx/onnx@36d7fff)**: Fix Attribute default value pybind11 binding (pytorch#671) <bddppq>
- **[0536866](onnx/onnx@0536866)**: git ignore .pytest_cache (pytorch#674) <bddppq>
- **[afc84ac](onnx/onnx@afc84ac)**: Update README.md (pytorch#672) <Dmytro Dzhulgakov>
- **[9d2b530](onnx/onnx@9d2b530)**: Revert "[Typing 1/3] Setup mypy type checker (pytorch#607)" (pytorch#667) <bddppq>
- **[086727e](onnx/onnx@086727e)**: [Typing 1/3] Setup mypy type checker (pytorch#607) <Sebastian Meßmer>
- **[5716e20](onnx/onnx@5716e20)**: Convert all Node tests to Model tests (pytorch#651) <bddppq>
- **[6fe932a](onnx/onnx@6fe932a)**: Replace unittest.skip with custom exception (pytorch#659) <Dmytro Dzhulgakov>
- **[ecac1c1](onnx/onnx@ecac1c1)**: Merge Rel 1.1.0 branch into master (pytorch#657) <Anirudh>
- **[5cb999d](onnx/onnx@5cb999d)**: Minor cleanups to shape inference (pytorch#653) <anderspapitto>
- **[f4acf28](onnx/onnx@f4acf28)**: Remove allowconsumed enforceconsumed from op schema. (pytorch#617) <Ke Zhang>
- **[a8e4648](onnx/onnx@a8e4648)**: Adjust link flags when built in Windows Debug mode (pytorch#647) <Yinghai Lu>
- **[7c009fe](onnx/onnx@7c009fe)**: Fix lint error in optimizer test (pytorch#656) <bddppq>
- **[063d12f](onnx/onnx@063d12f)**: Fix optimizer split pass for models with constant output (pytorch#652) <bddppq>
mrshenli pushed a commit to mrshenli/pytorch that referenced this pull request Apr 11, 2020
Add paths for downloaded files to static quantization tutorial
mcarilli pushed a commit to mcarilli/pytorch that referenced this pull request Mar 18, 2021
hubertlu-tw pushed a commit to hubertlu-tw/pytorch that referenced this pull request Nov 1, 2022
Actual flag is --opt_level and copy pasting the example results in an unrecognized arguments error.
mikaylagawarecki added a commit that referenced this pull request Jan 17, 2025
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)


Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Jan 17, 2025
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)


Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Jan 17, 2025
… (correctly)"

Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)


Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)

[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Jan 17, 2025
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)


Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jan 17, 2025
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)

Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)
Pull Request resolved: #145020
Approved by: https://github.com/kit1980, https://github.com/albanD
pytorchbot pushed a commit that referenced this pull request Jan 17, 2025
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)

Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)
Pull Request resolved: #145020
Approved by: https://github.com/kit1980, https://github.com/albanD

(cherry picked from commit 0eda02a)
kit1980 pushed a commit that referenced this pull request Jan 17, 2025
Prevent legacy_load when weights_only=True (correctly) (#145020)

Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)

Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)
Pull Request resolved: #145020
Approved by: https://github.com/kit1980, https://github.com/albanD

(cherry picked from commit 0eda02a)

Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
akashveramd pushed a commit to akashveramd/pytorch that referenced this pull request Apr 9, 2025
* fix headers for gpu instances

* remove unused headers

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0