Rewrite serialization to correctly handle partial reads/writes in all cases by ezyang · Pull Request #12143 · pytorch/pytorch

Closed
ezyang wants to merge 2 commits
Changes from 1 commit
Rewrite serialization to correctly handle partial reads/writes in all cases.

Previously, doRead/doWrite could return partial reads/writes, and the call sites in
serialization.cpp checked for this case inconsistently. Now these functions do NOT
return the number of bytes read/written; instead they handle the necessary retry loop
themselves, so a call either transfers exactly the requested number of bytes or throws.
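
For illustration, a minimal before/after sketch of a call site (adapted from the
generic/serialization.cpp hunks below; not new code in this PR):

    // Before: each call site checked the (possibly partial) result itself,
    // and sometimes got it wrong.
    ssize_t result = doRead(file, &size, sizeof(int64_t));
    if (result != sizeof(int64_t))
      throw std::system_error(result, std::system_category());

    // After: doRead returns void and either reads exactly sizeof(int64_t)
    // bytes or throws (looping over short reads and EINTR internally).
    doRead(file, &size, sizeof(int64_t));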

Fixes #12042.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang committed Sep 27, 2018
commit 8495ee6a4ccf50b6e6a2a40a7fa7ce400cf2a465
46 changes: 6 additions & 40 deletions torch/csrc/generic/serialization.cpp
@@ -2,8 +2,6 @@
#define TH_GENERIC_FILE "generic/serialization.cpp"
#else

#define SYSCHECK(call) { ssize_t __result = call; if (__result < 0) throw std::system_error((int) __result, std::system_category()); }

template <class io>
void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
{
@@ -16,23 +14,10 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
data = (scalar_t*)cpu_data.get();
THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
#endif
ssize_t result = doWrite(fd, &size, sizeof(int64_t));
if (result != sizeof(int64_t))
throw std::system_error(result, std::system_category());
doWrite(fd, &size, sizeof(int64_t));
// fast track for bytes and little endian
if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
char *bytes = (char *) data;
int64_t remaining = sizeof(scalar_t) * size;
while (remaining > 0) {
// we write and read in 1GB blocks to avoid bugs on some OSes
ssize_t result = doWrite(fd, bytes, THMin(remaining, 1073741824));
if (result < 0)
throw std::system_error(result, std::system_category());
bytes += result;
remaining -= result;
}
if (remaining != 0)
throw std::system_error(result, std::system_category());
doWrite(fd, data, sizeof(scalar_t) * size);
} else {
int64_t buffer_size = std::min(size, (int64_t)5000);
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
@@ -54,7 +39,7 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
}
SYSCHECK(doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t)));
doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t));
}
}
}
@@ -67,11 +52,7 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)
{
scalar_t *data;
int64_t size;
ssize_t result = doRead(file, &size, sizeof(int64_t));
if (result == 0)
throw std::runtime_error("unexpected EOF. The file might be corrupted.");
if (result != sizeof(int64_t))
throw std::system_error(result, std::system_category());
doRead(file, &size, sizeof(int64_t));
THWStoragePtr storage;
if (_storage == nullptr) {
storage = THWStorage_(newWithSize)(LIBRARY_STATE size);
@@ -91,28 +72,15 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)

// fast track for bytes and little endian
if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
char *bytes = (char *) data;
int64_t remaining = sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage);
while (remaining > 0) {
// we write and read in 1GB blocks to avoid bugs on some OSes
ssize_t result = doRead(file, bytes, THMin(remaining, 1073741824));
if (result == 0) // 0 means EOF, which is also an error
throw std::runtime_error("unexpected EOF. The file might be corrupted.");
if (result < 0)
throw std::system_error(result, std::system_category());
bytes += result;
remaining -= result;
}
if (remaining != 0)
throw std::system_error(result, std::system_category());
doRead(file, data, sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage));
} else {
int64_t buffer_size = std::min(size, (int64_t)5000);
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);


for (int64_t i = 0; i < size; i += buffer_size) {
size_t to_convert = std::min(size - i, buffer_size);
SYSCHECK(doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert));
doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert);

if (sizeof(scalar_t) == 2) {
THP_decodeInt16Buffer((int16_t*)data + i,
@@ -142,6 +110,4 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)
template THWStorage* THPStorage_(readFileRaw<int>)(int fd, THWStorage* storage);
template THWStorage* THPStorage_(readFileRaw<PyObject*>)(PyObject* fd, THWStorage* storage);

#undef SYSCHECK

#endif
141 changes: 102 additions & 39 deletions torch/csrc/serialization.cpp
@@ -4,34 +4,41 @@
#include "THP.h"
#include "serialization.h"

static ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
template <class io>
ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);

template <class io>
ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);

static ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes);

template <>
ssize_t doRead<int>(int fildes, void* buf, size_t nbytes) {
ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
return read(fildes, buf, nbytes);
}

template <>
ssize_t doRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
ssize_t doPartialRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
// Try to use fildes.readinto() instead of fildes.read()
// because it is more memory efficient.
// TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop

auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
if (has_readinto) {
return doPythonReadInto(fildes, buf, nbytes);
return doPartialPythonReadInto(fildes, buf, nbytes);
}
return doPythonReadBuffered(fildes, buf, nbytes);
return doPartialPythonReadBuffered(fildes, buf, nbytes);
}

template <>
ssize_t doWrite<int>(int fildes, void* buf, size_t nbytes) {
ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
return write(fildes, buf, nbytes);
}

template <>
ssize_t doWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
return doPythonWrite(fildes, buf, nbytes);
ssize_t doPartialWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
return doPartialPythonWrite(fildes, buf, nbytes);
}

static inline bool isUnsupportedOperation() {
@@ -43,39 +50,39 @@ static inline bool isUnsupportedOperation() {
}

// Call Python fildes.read(nbytes) and copy it to buf.
static inline ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes) {
const size_t buffer_size = 262144; // 2^18
size_t read_bytes = 0;

while (read_bytes < nbytes) {
auto remaining = nbytes - read_bytes;
auto to_read = remaining > buffer_size ? buffer_size : remaining;
THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", to_read));
if (!r) throw python_error();

// read output is String (Python 2) / Bytes (Python 3)
static inline ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t raw_nbytes) {
// If we request a large amount of data, f.read() will internally try to
// allocate a buffer of that size. This is counterproductive, because
// it's not the buffer we ultimately want to write the data into. Read
// less than that and avoid allocating too much extra memory.
// TODO: Maybe 260 KB is a bit small...
const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (~260 KB)

THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes));
if (!r) throw python_error();

// read output is String (Python 2) / Bytes (Python 3)
#if PY_MAJOR_VERSION >= 3
auto size = PyBytes_GET_SIZE(r.get());
const void* bytes = PyBytes_AsString(r.get());
auto size = PyBytes_GET_SIZE(r.get());
const void* py_buf = PyBytes_AsString(r.get());
#else
auto size = PyString_GET_SIZE(r.get());
const void* bytes = PyString_AsString(r.get());
auto size = PyString_GET_SIZE(r.get());
const void* py_buf = PyString_AsString(r.get());
#endif

// we read EOF
if (size == 0) {
return read_bytes;
}
// we read EOF
if (size == 0) {
return 0;
}

memcpy(reinterpret_cast<char*>(buf) + read_bytes, bytes, size);
read_bytes += size;
} // Reading loop
// Slurp it into the buffer we actually want
memcpy(buf, py_buf, size);

return read_bytes;
return size;
}

// Either does fildes.readinto(buf) or fildes.write(buf)
static inline ssize_t doPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
static inline ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
#if PY_MAJOR_VERSION >= 3
auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
THPObjectPtr memview(PyMemoryView_FromMemory(
@@ -97,19 +104,75 @@ static inline ssize_t doPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read)
// fildes.readinto can return UnsupportedOperation so fall back to fildes.read.
if (is_read && isUnsupportedOperation()) {
PyErr_Clear();
return doPythonReadBuffered(fildes, buf, nbytes);
return doPartialPythonReadBuffered(fildes, buf, nbytes);
}
throw python_error();
}

// Call Python fildes.readinto(buf)
static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
return doPythonIO(fildes, buf, nbytes, /* is_read */ true);
static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
}

// Call Python fildes.write(buf)
static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
return doPythonIO(fildes, buf, nbytes, /* is_read */ false);
static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
}

// Requires that we read EXACTLY nbytes; fails if we don't.
template <typename io>
void doRead(io fildes, void* raw_buf, size_t nbytes) {
char* buf = static_cast<char*>(raw_buf);
while (nbytes > 0) {
errno = 0; // doPartialRead may not set errno
// we read in 1GB blocks to avoid bugs on some OSes

ssize_t r = doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
if (r < 0) {
int err = errno;
AT_ASSERTM(err != 0, "read(): impossible! r < 0, but no errno was set");
AT_ASSERTM(err != EAGAIN, "read(): non-blocking fd ", fildes,
" read EAGAIN; cowardly refusing to spin-wait");
if (err == EINTR) {
continue;
} else {
AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
}
} else if (r == 0) {
break;
}
buf += r;
// This is guaranteed by POSIX, but I just want to be double-sure
// to not underflow a signed integer.
AT_ASSERT(static_cast<size_t>(r) <= nbytes);
nbytes -= r;
}
if (nbytes != 0) {
AT_ERROR("unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted.");
}
}

template <typename io>
void doWrite(io fildes, void* raw_buf, size_t nbytes) {
char* buf = static_cast<char*>(raw_buf);
while (nbytes > 0) {
errno = 0; // doPartialWrite may not set errno
// we write in 1GB blocks to avoid bugs on some OSes
ssize_t r = doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
if (r < 0) {
int err = errno;
AT_ASSERTM(err != 0, "write(): impossible! r < 0, but no errno was set");
AT_ASSERTM(err != EAGAIN, "write(): non-blocking fd ", fildes,
" read EAGAIN; cowardly refusing to spin-wait");
if (err == EINTR) {
continue;
} else {
AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
}
}
buf += r;
AT_ASSERT(static_cast<size_t>(r) <= nbytes);
nbytes -= r;
}
}

#include "generic/serialization.cpp"
4 changes: 2 additions & 2 deletions torch/csrc/serialization.h
@@ -8,9 +8,9 @@
#include <TH/THGenerateHalfType.h>

template <class io>
ssize_t doRead(io fildes, void* buf, size_t nbytes);
void doRead(io fildes, void* buf, size_t nbytes);

template <class io>
ssize_t doWrite(io fildes, void* buf, size_t nbytes);
void doWrite(io fildes, void* buf, size_t nbytes);

#endif
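
For reference, a hypothetical caller of the new interface would look roughly like this
(a sketch only, using names from the diff above; `fd`, `data`, and `scalar_t` are
assumed to be set up by the surrounding storage code):

    // Sketch (not part of this diff): the new doRead/doWrite either complete
    // the full transfer or throw, so callers no longer inspect a byte count.
    int64_t size;
    doRead(fd, &size, sizeof(int64_t));          // reads exactly 8 bytes or throws
    doWrite(fd, data, sizeof(scalar_t) * size);  // writes everything or throws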