more careful error checking in C serialization.. · pytorch/pytorch@a1c2819 · GitHub
[go: up one dir, main page]

Skip to content

Commit a1c2819

Browse files
committed
more careful error checking in C serialization..
1 parent 9eb81e2 commit a1c2819

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

torch/csrc/generic/StorageMethods.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,8 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file)
191191
THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
192192
"descriptor from given object");
193193
THStorage *storage = THPStorage_(readFileRaw)(fd, nullptr);
194+
if (storage == nullptr)
195+
return nullptr;
194196
PyObject *result = THPStorage_(New)(storage);
195197
return result;
196198
END_HANDLE_TH_ERRORS
@@ -209,7 +211,9 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
209211

210212
THPUtils_assert(fd != -1, "_set_from_file couldn't retrieve a file "
211213
"descriptor from given object");
212-
THPStorage_(readFileRaw)(fd, self->cdata);
214+
THStorage *storage = THPStorage_(readFileRaw)(fd, self->cdata);
215+
if (storage == nullptr)
216+
return nullptr;
213217
Py_INCREF(self);
214218

215219
return (PyObject *) self;

torch/csrc/generic/serialization.cpp

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -29,30 +29,35 @@ THTensor * THPTensor_(newWithMetadataFileRaw)(int fd, THStorage *storage)
2929
void THPStorage_(writeFileRaw)(THStorage *self, int fd)
3030
{
3131
real *data;
32+
int64_t size = self->size;
3233
#ifndef THC_GENERIC_FILE
3334
data = self->data;
3435
#else
35-
std::unique_ptr<char[]> cpu_data(new char[self->size * sizeof(real)]);
36+
std::unique_ptr<char[]> cpu_data(new char[size * sizeof(real)]);
3637
data = (real*)cpu_data.get();
37-
THCudaCheck(cudaMemcpy(data, self->data, self->size * sizeof(real), cudaMemcpyDeviceToHost));
38+
THCudaCheck(cudaMemcpy(data, self->data, size * sizeof(real), cudaMemcpyDeviceToHost));
3839
#endif
39-
SYSCHECK(write(fd, &self->size, sizeof(long)));
40+
ssize_t result = write(fd, &size, sizeof(int64_t));
41+
if (result != sizeof(int64_t))
42+
throw std::system_error(result, std::system_category());
4043
// fast track for bytes and little endian
4144
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
4245
char *bytes = (char *) data;
43-
uint64_t remaining = sizeof(real) * self->size;
46+
int64_t remaining = sizeof(real) * size;
4447
while (remaining > 0) {
4548
ssize_t result = write(fd, bytes, remaining);
4649
if (result < 0)
4750
throw std::system_error(result, std::system_category());
4851
bytes += result;
4952
remaining -= result;
5053
}
54+
if (remaining != 0)
55+
throw std::system_error(result, std::system_category());
5156
} else {
52-
int64_t buffer_size = std::min(self->size, (long)5000);
57+
int64_t buffer_size = std::min(size, (long)5000);
5358
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
54-
for (int64_t i = 0; i < self->size; i += buffer_size) {
55-
size_t to_convert = std::min(self->size - i, buffer_size);
59+
for (int64_t i = 0; i < size; i += buffer_size) {
60+
size_t to_convert = std::min(size - i, buffer_size);
5661
if (sizeof(real) == 2) {
5762
THP_encodeInt16Buffer((uint8_t*)le_buffer.get(),
5863
(const int16_t*)data + i,
@@ -77,14 +82,17 @@ void THPStorage_(writeFileRaw)(THStorage *self, int fd)
7782
THStorage * THPStorage_(readFileRaw)(int fd, THStorage *_storage)
7883
{
7984
real *data;
80-
long size;
81-
SYSCHECK(read(fd, &size, sizeof(long)));
82-
85+
int64_t size;
86+
ssize_t result = read(fd, &size, sizeof(int64_t));
87+
if (result != sizeof(int64_t))
88+
throw std::system_error(result, std::system_category());
8389
THStoragePtr storage;
8490
if (_storage == nullptr) {
8591
storage = THStorage_(newWithSize)(LIBRARY_STATE size);
8692
} else {
87-
THPUtils_assert(_storage->size == size, "storage has wrong size");
93+
THPUtils_assert(_storage->size == size,
94+
"storage has wrong size: expected %ld got %ld",
95+
size, _storage->size);
8896
storage = _storage;
8997
}
9098

@@ -98,14 +106,16 @@ THStorage * THPStorage_(readFileRaw)(int fd, THStorage *_storage)
98106
// fast track for bytes and little endian
99107
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
100108
char *bytes = (char *) data;
101-
uint64_t remaining = sizeof(real) * storage->size;
109+
int64_t remaining = sizeof(real) * storage->size;
102110
while (remaining > 0) {
103111
ssize_t result = read(fd, bytes, remaining);
104-
if (result < 0)
112+
if (result <= 0) // 0 means EOF, which is also an error
105113
throw std::system_error(result, std::system_category());
106114
bytes += result;
107115
remaining -= result;
108116
}
117+
if (remaining != 0)
118+
throw std::system_error(result, std::system_category());
109119
} else {
110120
int64_t buffer_size = std::min(size, (long)5000);
111121
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);

torch/serialization.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
111111
pickle_protocol: can be specified to override the default protocol
112112
"""
113113
new_fd = False
114-
if isinstance(f, str):
114+
if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
115115
new_fd = True
116116
f = open(f, "wb")
117117
try:
@@ -213,7 +213,7 @@ def load(f, map_location=None, pickle_module=pickle):
213213
the pickle_module used to serialize file)
214214
"""
215215
new_fd = False
216-
if isinstance(f, str):
216+
if isinstance(f, str) or (sys.version_info[0] == 2 and isinstance(f, unicode)):
217217
new_fd = True
218218
f = open(f, 'rb')
219219
try:

0 commit comments

Comments
 (0)
0