8000 improved serialization (no tar copy) (#713) · pytorch/pytorch@e71cf20 · GitHub
[go: up one dir, main page]

Skip to content

Commit e71cf20

Browse files
adamlerer authored and apaszke committed
improved serialization (no tar copy) (#713)
1 parent adb4cb2 commit e71cf20

File tree

9 files changed

+322
-178
lines changed

9 files changed

+322
-178
lines changed

test/common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import os
23
import argparse
34
import unittest
45
import contextlib
@@ -212,3 +213,25 @@ def get_numerical_jacobian(fn, input, target):
212213
d_tensor[i] = outb
213214

214215
return jacobian
216+
217+
218+
def download_file(url, path, binary=True):
219+
if sys.version_info < (3,):
220+
import urllib2
221+
request = urllib2
222+
error = urllib2
223+
else:
224+
import urllib.request
225+
import urllib.error
226+
request = urllib.request
227+
error = urllib.error
228+
229+
if os.path.exists(path):
230+
return True
231+
try:
232+
data = request.urlopen(url, timeout=15).read()
233+
with open(path, 'wb' if binary else 'w') as f:
234+
f.write(data)
235+
return True
236+
except error.URLError as e:
237+
return False

test/test_torch.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import os
23
import math
34
import random
45
import torch
@@ -8,7 +9,7 @@
89
import warnings
910
from itertools import product, chain
1011
from functools import wraps
11-
from common import TestCase, iter_indices, TEST_NUMPY, run_tests
12+
from common import TestCase, iter_indices, TEST_NUMPY, run_tests, download_file
1213

1314
if TEST_NUMPY:
1415
import numpy as np
@@ -2449,7 +2450,9 @@ def test_copy(self):
24492450
a_clone = a.clone()
24502451
b = copy(a)
24512452
b.fill_(1)
2452-
self.assertEqual(a, a_clone)
2453+
# copy is a shallow copy, only copies the tensor view,
2454+
# not the data
2455+
self.assertEqual(a, b)
24532456

24542457
def test_pickle(self):
24552458
if sys.version_info[0] == 2:
@@ -2521,6 +2524,11 @@ def test_serialization(self):
25212524
b = [a[i % 2] for i in range(4)]
25222525
b += [a[0].storage()]
25232526
b += [a[0].storage()[1:4]]
2527+
b += [torch.range(1, 10).int()]
2528+
t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
2529+
t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
2530+
b += [(t1.storage(), t1.storage(), t2.storage())]
2531+
b += [a[0].storage()[0:2]]
25242532
for use_name in (False, True):
25252533
with tempfile.NamedTemporaryFile() as f:
25262534
handle = f if not use_name else f.name
@@ -2540,6 +2548,60 @@ def test_serialization(self):
25402548
self.assertEqual(c[1], c[3], 0)
25412549
self.assertEqual(c[4], c[5][1:4], 0)
25422550

2551+
# check that serializing the same storage view object unpickles
2552+
# it as one object not two (and vice versa)
2553+
views = c[7]
2554+
self.assertEqual(views[0]._cdata, views[1]._cdata)
2555+
self.assertEqual(views[0], views[2])
2556+
self.assertNotEqual(views[0]._cdata, views[2]._cdata)
2557+
2558+
rootview = c[8]
2559+
self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
2560+
2561+
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
2562+
def test_serialization_cuda(self):
2563+
device_count = torch.cuda.device_count()
2564+
t0 = torch.cuda.FloatTensor(5).fill_(1)
2565+
torch.cuda.set_device(device_count - 1)
2566+
tn = torch.cuda.FloatTensor(3).fill_(2)
2567+
torch.cuda.set_device(0)
2568+
b = (t0, tn)
2569+
with tempfile.NamedTemporaryFile() as f:
2570+
torch.save(b, f)
2571+
f.seek(0)
2572+
c = torch.load(f)
2573+
self.assertEqual(b, c, 0)
2574+
u0, un = c
2575+
self.assertEqual(u0.get_device(), 0)
2576+
self.assertEqual(un.get_device(), device_count - 1)
2577+
2578+
def test_serialization_backwards_compat(self):
2579+
a = [torch.range(1 + i, 25 + i).view(5, 5).float() for i in range(2)]
2580+
b = [a[i % 2] for i in range(4)]
2581+
b += [a[0].storage()]
2582+
b += [a[0].storage()[1:4]]
2583+
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_serialized.pt'
2584+
data_dir = os.path.join(os.path.dirname(__file__), 'data')
2585+
test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')
2586+
succ = download_file(DATA_URL, test_file_path)
2587+
if not succ:
2588+
warnings.warn(("Couldn't download the test file for backwards compatibility! "
2589+
"Tests will be incomplete!"), RuntimeWarning)
2590+
return
2591+
c = torch.load(test_file_path)
2592+
self.assertEqual(b, c, 0)
2593+
self.assertTrue(isinstance(c[0], torch.FloatTensor))
2594+
self.assertTrue(isinstance(c[1], torch.FloatTensor))
2595+
self.assertTrue(isinstance(c[2], torch.FloatTensor))
2596+
self.assertTrue(isinstance(c[3], torch.FloatTensor))
2597+
self.assertTrue(isinstance(c[4], torch.FloatStorage))
2598+
c[0].fill_(10)
2599+
self.assertEqual(c[0], c[2], 0)
2600+
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
2601+
c[1].fill_(20)
2602+
self.assertEqual(c[1], c[3], 0)
2603+
self.assertEqual(c[4], c[5][1:4], 0)
2604+
25432605
def test_serialization_container(self):
25442606
def import_module(name, filename):
25452607
if sys.version_info >= (3, 5):

test/test_utils.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
HAS_CUDA = torch.cuda.is_available()
2121

22-
from common import TestCase, run_tests
22+
from common import TestCase, run_tests, download_file
2323

2424
try:
2525
import cffi
@@ -296,35 +296,13 @@ def do_test(self):
296296
self.assertEqual(grad_input, test['grad_input'])
297297
return do_test
298298

299-
@classmethod
300-
def _download_data(cls, test_file_path):
301-
if os.path.exists(test_file_path):
302-
return
303-
print('Downloading test file for TestLuaReader.')
304-
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_modules.t7'
305-
urllib = cls._get_urllib('request')
306-
data = urllib.urlopen(DATA_URL, timeout=15).read()
307-
with open(test_file_path, 'wb') as f:
308-
f.write(data)
309-
310-
@staticmethod
311-
def _get_urllib(submodule):
312-
if sys.version_info < (3,):
313-
import urllib2
314-
return urllib2
315-
else:
316-
import urllib.error
317-
import urllib.request
318-
return getattr(urllib, submodule)
319-
320299
@classmethod
321300
def init(cls):
301+
DATA_URL = 'https://s3.amazonaws.com/pytorch/legacy_modules.t7'
322302
data_dir = os.path.join(os.path.dirname(__file__), 'data')
323303
test_file_path = os.path.join(data_dir, 'legacy_modules.t7')
324-
urllib = cls._get_urllib('error')
325-
try:
326-
cls._download_data(test_file_path)
327-
except urllib.URLError as e:
304+
succ = download_file(DATA_URL, test_file_path)
305+
if not succ:
328306
warnings.warn(("Couldn't download the test file for TestLuaReader! "
329307
"Tests will be incomplete!"), RuntimeWarning)
330308
return

torch/csrc/generic/StorageMethods.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,36 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file)
190190
int fd = PyObject_AsFileDescriptor(file);
191191
THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
192192
"descriptor from given object");
193-
THStoragePtr storage = THPStorage_(readFileRaw)(fd);
193+
THStorage *storage = THPStorage_(readFileRaw)(fd, nullptr);
194+
if (storage == nullptr)
195+
return nullptr;
194196
PyObject *result = THPStorage_(New)(storage);
195-
storage.release();
196197
return result;
197198
END_HANDLE_TH_ERRORS
198199
}
199200

201+
static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
202+
{
203+
HANDLE_TH_ERRORS
204+
PyObject *file = PyTuple_GET_ITEM(args, 0);
205+
int fd = PyObject_AsFileDescriptor(file);
206+
207+
PyObject *offset = PyTuple_GET_ITEM(args, 1);
208+
if (offset != Py_None) {
209+
lseek(fd, THPUtils_unpackLong(offset), SEEK_SET);
210+
}
211+
212+
THPUtils_assert(fd != -1, "_set_from_file couldn't retrieve a file "
213+
"descriptor from given object");
214+
THStorage *storage = THPStorage_(readFileRaw)(fd, self->cdata);
215+
if (storage == nullptr)
216+
return nullptr;
217+
Py_INCREF(self);
218+
219+
return (PyObject *) self;
220+
END_HANDLE_TH_ERRORS
221+
}
222+
200223
#ifdef THC_GENERIC_FILE
201224
PyObject * THPStorage_(getDevice)(THPStorage *self)
202225
{
@@ -250,6 +273,7 @@ static PyMethodDef THPStorage_(methods)[] = {
250273
{"is_pinned", (PyCFunction)THPStorage_(isPinned), METH_NOARGS, NULL},
251274
{"_write_file", (PyCFunction)THPStorage_(writeFile), METH_O, NULL},
252275
{"_new_with_file", (PyCFunction)THPStorage_(newWithFile), METH_O | METH_STATIC, NULL},
276+
{"_set_from_file", (PyCFunction)THPStorage_(setFromFile), METH_VARARGS, NULL},
253277
#ifndef THC_GENERIC_FILE
254278
{"from_buffer", (PyCFunction)THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
255279
#endif

torch/csrc/generic/serialization.cpp

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,35 @@ THTensor * THPTensor_(newWithMetadataFileRaw)(int fd, THStorage *storage)
2929
void THPStorage_(writeFileRaw)(THStorage *self, int fd)
3030
{
3131
real *data;
32+
int64_t size = self->size;
3233
#ifndef THC_GENERIC_FILE
3334
data = self->data;
3435
#else
35-
std::unique_ptr<char[]> cpu_data(new char[self->size * sizeof(real)]);
36+
std::unique_ptr<char[]> cpu_data(new char[size * sizeof(real)]);
3637
data = (real*)cpu_data.get();
37-
THCudaCheck(cudaMemcpy(data, self->data, self->size * sizeof(real), cudaMemcpyDeviceToHost));
38+
THCudaCheck(cudaMemcpy(data, self->data, size * sizeof(real), cudaMemcpyDeviceToHost));
3839
#endif
39-
SYSCHECK(write(fd, &self->size, sizeof(long)));
40+
ssize_t result = write(fd, &size, sizeof(int64_t));
41+
if (result != sizeof(int64_t))
42+
throw std::system_error(result, std::system_category());
4043
// fast track for bytes and little endian
4144
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
4245
char *bytes = (char *) data;
43-
uint64_t remaining = sizeof(real) * self->size;
46+
int64_t remaining = sizeof(real) * size;
4447
while (remaining > 0) {
4548
ssize_t result = write(fd, bytes, remaining);
4649
if (result < 0)
4750
throw std::system_error(result, std::system_category());
4851
bytes += result;
4952
remaining -= result;
5053
}
54+
if (remaining != 0)
55+
throw std::system_error(result, std::system_category());
5156
} else {
52-
int64_t buffer_size = std::min(self->size, (long)5000);
57+
int64_t buffer_size = std::min(size, (int64_t)5000);
5358
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
54-
for (int64_t i = 0; i < self->size; i += buffer_size) {
55-
size_t to_convert = std::min(self->size - i, buffer_size);
59+
for (int64_t i = 0; i < size; i += buffer_size) {
60+
size_t to_convert = std::min(size - i, buffer_size);
5661
if (sizeof(real) == 2) {
5762
THP_encodeInt16Buffer((uint8_t*)le_buffer.get(),
5863
(const int16_t*)data + i,
@@ -74,12 +79,22 @@ void THPStorage_(writeFileRaw)(THStorage *self, int fd)
7479
}
7580
}
7681

77-
THStorage * THPStorage_(readFileRaw)(int fd)
82+
THStorage * THPStorage_(readFileRaw)(int fd, THStorage *_storage)
7883
{
7984
real *data;
80-
long size;
81-
SYSCHECK(read(fd, &size, sizeof(long)));
82-
THStoragePtr storage = THStorage_(newWithSize)(LIBRARY_STATE size);
85+
int64_t size;
86+
ssize_t result = read(fd, &size, sizeof(int64_t));
87+
if (result != sizeof(int64_t))
88+
throw std::system_error(result, std::system_category());
89+
THStoragePtr storage;
90+
if (_storage == nullptr) {
91+
storage = THStorage_(newWithSize)(LIBRARY_STATE size);
92+
} else {
93+
THPUtils_assert(_storage->size == size,
94+
"storage has wrong size: expected %ld got %ld",
95+
size, _storage->size);
96+
storage = _storage;
97+
}
8398

8499
#ifndef THC_GENERIC_FILE
85100
data = storage->data;
@@ -91,16 +106,18 @@ THStorage * THPStorage_(readFileRaw)(int fd)
91106
// fast track for bytes and little endian
92107
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
93108
char *bytes = (char *) data;
94-
uint64_t remaining = sizeof(real) * storage->size;
109+
int64_t remaining = sizeof(real) * storage->size;
95110
while (remaining > 0) {
96111
ssize_t result = read(fd, bytes, remaining);
97-
if (result < 0)
112+
if (result <= 0) // 0 means EOF, which is also an error
98113
throw std::system_error(result, std::system_category());
99114
bytes += result;
100115
remaining -= result;
101116
}
117+
if (remaining != 0)
118+
throw std::system_error(result, std::system_category());
102119
} else {
103-
int64_t buffer_size = std::min(size, (long)5000);
120+
int64_t buffer_size = std::min(size, (int64_t)5000);
104121
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
105122
for (int64_t i = 0; i < size; i += buffer_size) {
106123
size_t to_convert = std::min(size - i, buffer_size);

torch/csrc/generic/serialization.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
void THPTensor_(writeMetadataRaw)(THTensor *self, int fd);
66
THTensor * THPTensor_(newWithMetadataFileRaw)(int fd, THStorage *storage);
77
void THPStorage_(writeFileRaw)(THStorage *self, int fd);
8-
THStorage * THPStorage_(readFileRaw)(int fd);
8+
THStorage * THPStorage_(readFileRaw)(int fd, THStorage *storage);
99

1010
#endif

torch/multiprocessing/reductions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def rebuild_storage_filename(cls, manager, handle, size):
8888
return storage._shared_decref()
8989

9090

91-
def reubild_storage_cuda(cls, device, handle, size, offset, view_size):
91+
def rebuild_storage_cuda(cls, device, handle, size, offset, view_size):
9292
storage = storage_from_cache(cls, handle)
9393
if storage is not None:
9494
return storage._new_view(offset, view_size)
@@ -103,7 +103,7 @@ def reduce_storage(storage):
103103
if storage.is_cuda:
104104
metadata = storage._share_cuda_()
105105
cache_key = metadata[1]
106-
rebuild = reubild_storage_cuda
106+
rebuild = rebuild_storage_cuda
107107
elif get_sharing_strategy() == 'file_system':
108108
metadata = storage._share_filename_()
109109
cache_key = metadata[1]

0 commit comments

Comments (0)
0