8000 bpo-46805: Add low level UDP socket functions to asyncio by agronholm · Pull Request #31455 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-46805: Add low level UDP socket functions to asyncio #31455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed recvfrom_into()
  • Loading branch information
agronholm committed Feb 27, 2022
commit ebd23534213961b9ff0a417ad233b6cc6c0c265b
5 changes: 4 additions & 1 deletion Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,13 +703,16 @@ async def sock_recvfrom(self, sock, bufsize):
return await self._proactor.recvfrom(sock, bufsize)

async def sock_recvfrom_into(self, sock, buf, nbytes=0):
if not nbytes:
nbytes = len(buf)

return await self._proactor.recvfrom_into(sock, buf, nbytes)

async def sock_sendall(self, sock, data):
return await self._proactor.send(sock, data)

async def sock_sendto(self, sock, data, address):
return await self._proactor.send(sock, data, 0, address)
return await self._proactor.sendto(sock, data, 0, address)

async def sock_connect(self, sock, address):
return await self._proactor.connect(sock, address)
Expand Down
3 changes: 3 additions & 0 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,9 @@ async def sock_recvfrom_into(self, sock, buf, nbytes=0):
_check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
if not nbytes:
nbytes = len(buf)

try:
return sock.recvfrom_into(buf, nbytes)
except (BlockingIOError, InterruptedError):
Expand Down
14 changes: 9 additions & 5 deletions Lib/test/test_asyncio/test_sock_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,12 @@ async def _basetest_datagram_recvfrom_into(self, server_address):
self.assertEqual(buf, data)
self.assertEqual(from_addr, server_address)

buf = bytearray(4096)
buf = bytearray(8192)
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf, 2048)
self.assertEqual(num_bytes, 2048)
self.assertEqual(buf[:2048], data[:2048])
sock, buf, 4096)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf[:4096], data[:4096])
self.assertEqual(from_addr, server_address)

def test_recvfrom_into(self):
Expand All @@ -430,7 +430,7 @@ async def _basetest_datagram_sendto_blocking(self, server_address):
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
mock_sock = Mock(socket.socket)
mock_sock = Mock(sock)
mock_sock.gettimeout = sock.gettimeout
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
mock_sock.fileno = sock.fileno
Expand All @@ -445,6 +445,10 @@ async def _basetest_datagram_sendto_blocking(self, server_address):
self.assertEqual(from_addr, server_address)

def test_sendto_blocking(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
raise unittest.SkipTest('Not relevant to ProactorEventLoop')

with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_sendto_blocking(server_address))
Expand Down
169 changes: 116 additions & 53 deletions Modules/overlapped.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class _overlapped.Overlapped "OverlappedObject *" "&OverlappedType"
enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE, TYPE_READ_FROM,
TYPE_WRITE_TO};
TYPE_WRITE_TO, TYPE_READ_FROM_INTO};

typedef struct {
PyObject_HEAD
Expand All @@ -91,6 +91,17 @@ typedef struct {
struct sockaddr_in6 address;
int address_length;
} read_from;

/* Data used for reading from a connectionless socket:
TYPE_READ_FROM_INTO */
struct {
// A (number of bytes read, (host, port)) tuple
PyObject* result;
/* Buffer passed by the user */
Py_buffer user_buffer;
struct sockaddr_in6 address;
int address_length;
} read_from_into;
};
} OverlappedObject;

Expand Down Expand Up @@ -662,6 +673,16 @@ Overlapped_clear(OverlappedObject *self)
}
break;
}
case TYPE_READ_FROM_INTO: {
if (self->read_from.result) {
// We've received a message, free the result tuple.
Py_CLEAR(self->read_from.result);
}
if (self->read_from_into.user_buffer.obj) {
PyBuffer_Release(&self->read_from_into.user_buffer);
}
break;
}
case TYPE_WRITE:
case TYPE_WRITE_TO:
case TYPE_READINTO: {
Expand Down Expand Up @@ -914,6 +935,30 @@ _overlapped_Overlapped_getresult_impl(OverlappedObject *self, BOOL wait)

Py_INCREF(self->read_from.result);
return self->read_from.result;
case TYPE_READ_FROM_INTO:
// unparse the address
addr = unparse_address((SOCKADDR*)&self->read_from_into.address,
self->read_from_into.address_length);

if (addr == NULL) {
return NULL;
}

// The result is a two item tuple: (number of bytes read, address)
self->read_from_into.result = PyTuple_New(2);
if (self->read_from_into.result == NULL) {
Py_CLEAR(addr);
return NULL;
}

// first item: number of bytes read
PyTuple_SET_ITEM(self->read_from_into.result, 0,
PyLong_FromUnsignedLong((unsigned long)transferred));
// second item: address
PyTuple_SET_ITEM(self->read_from_into.result, 1, addr);

Py_INCREF(self->read_from_into.result);
return self->read_from_into.result;
default:
return PyLong_FromUnsignedLong((unsigned long) transferred);
}
Expand Down Expand Up @@ -1053,45 +1098,6 @@ do_WSARecv(OverlappedObject *self, HANDLE handle,
}
}

static PyObject *
do_WSARecvFrom(OverlappedObject *self, HANDLE handle,
PyObject *bufobj, DWORD buflen, DWORD flags)
{
DWORD nread;
WSABUF wsabuf;
int ret;
DWORD err;

wsabuf.buf = PyBytes_AS_STRING(bufobj);
wsabuf.len = buflen;

self->type = TYPE_READ_FROM;
self->handle = handle;
self->read_from.allocated_buffer = bufobj;
memset(&self->read_from.address, 0, sizeof(self->read_from.address));
self->read_from.address_length = sizeof(self->read_from.address);

Py_BEGIN_ALLOW_THREADS
ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
(SOCKADDR*)&self->read_from.address,
&self->read_from.address_length,
&self->overlapped, NULL);
Py_END_ALLOW_THREADS

self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
switch(err) {
case ERROR_BROKEN_PIPE:
mark_as_completed(&self->overlapped);
return SetFromWindowsErr(err);
case ERROR_SUCCESS:
case ERROR_MORE_DATA:
case ERROR_IO_PENDING:
Py_RETURN_NONE;
default:
self->type = TYPE_NOT_STARTED;
return SetFromWindowsErr(err);
}
}

/*[clinic input]
_overlapped.Overlapped.WSARecv
Expand Down Expand Up @@ -1821,7 +1827,40 @@ _overlapped_Overlapped_WSARecvFrom_impl(OverlappedObject *self,
return NULL;
}

return do_WSARecvFrom(self, handle, buf, size, flags);
DWORD nread;
WSABUF wsabuf;
int ret;
DWORD err;

wsabuf.buf = PyBytes_AS_STRING(buf);
wsabuf.len = size;

self->type = TYPE_READ_FROM;
self->handle = handle;
self->read_from.allocated_buffer = buf;
memset(&self->read_from.address, 0, sizeof(self->read_from.address));
self->read_from.address_length = sizeof(self->read_from.address);

Py_BEGIN_ALLOW_THREADS
ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
(SOCKADDR*)&self->read_from.address,
&self->read_from.address_length,
&self->overlapped, NULL);
Py_END_ALLOW_THREADS

self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
switch (err) {
case ERROR_BROKEN_PIPE:
mark_as_completed(&self->overlapped);
return SetFromWindowsErr(err);
case ERROR_SUCCESS:
case ERROR_MORE_DATA:
case ERROR_IO_PENDING:
Py_RETURN_NONE;
default:
self->type = TYPE_NOT_STARTED;
return SetFromWindowsErr(err);
}
}


Expand All @@ -1848,26 +1887,50 @@ _overlapped_Overlapped_WSARecvFromInto_impl(OverlappedObject *self,
return NULL;
}

Py_buffer buffer;
if (!PyArg_Parse(bufobj, "y*", &buffer))
if (!PyArg_Parse(bufobj, "y*", &self->read_from_into.user_buffer))
return NULL;

#if SIZEOF_SIZE_T > SIZEOF_LONG
if (buffer.len > (Py_ssize_t)ULONG_MAX) {
PyBuffer_Release(&buffer);
if (self->read_from_into.user_buffer.len > (Py_ssize_t)ULONG_MAX) {
PyBuffer_Release(&self->read_from_into.user_buffer);
PyErr_SetString(PyExc_ValueError, "buffer too large");
return NULL;
}
#endif
if (buffer.len < size) {
PyBuffer_Release(&buffer);
PyErr_SetString(PyExc_ValueError,
"nbytes is greater than the length of the buffer");
return NULL;
}

PyBuffer_Release(&buffer);
return do_WSARecvFrom(self, handle, bufobj, size, flags);
DWORD nread;
WSABUF wsabuf;
int ret;
DWORD err;

wsabuf.buf = self->read_from_into.user_buffer.buf;
wsabuf.len = size;

self->type = TYPE_READ_FROM_INTO;
self->handle = handle;
memset(&self->read_from_into.address, 0, sizeof(self->read_from_into.address));
self->read_from_into.address_length = sizeof(self->read_from_into.address);

Py_BEGIN_ALLOW_THREADS
ret = WSARecvFrom((SOCKET)handle, &wsabuf, 1, &nread, &flags,
(SOCKADDR*)&self->read_from_into.address,
&self->read_from_into.address_length,
&self->overlapped, NULL);
Py_END_ALLOW_THREADS

self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS);
switch (err) {
case ERROR_BROKEN_PIPE:
mark_as_completed(&self->overlapped);
return SetFromWindowsErr(err);
case ERROR_SUCCESS:
case ERROR_MORE_DATA:
case ERROR_IO_PENDING:
Py_RETURN_NONE;
default:
self->type = TYPE_NOT_STARTED;
return SetFromWindowsErr(err);
}
}


Expand Down
0