Update multiprocessing from CPython 3.12.3 by youknowone · Pull Request #5263 · RustPython/RustPython · GitHub
[go: up one dir, main page]

Skip to content

Update multiprocessing from CPython 3.12.3 #5263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 227 additions & 27 deletions Lib/multiprocessing/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

__all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ]

import errno
import io
import os
import sys
Expand Down Expand Up @@ -73,11 +74,6 @@ def arbitrary_address(family):
if family == 'AF_INET':
return ('localhost', 0)
elif family == 'AF_UNIX':
# Prefer abstract sockets if possible to avoid problems with the address
# size. When coding portable applications, some implementations have
# sun_path as short as 92 bytes in the sockaddr_un struct.
if util.abstract_sockets_supported:
return f"\0listener-{os.getpid()}-{next(_mmap_counter)}"
return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
Expand Down Expand Up @@ -188,10 +184,9 @@ def send_bytes(self, buf, offset=0, size=None):
self._check_closed()
self._check_writable()
m = memoryview(buf)
# HACK for byte-indexing of non-bytewise buffers (e.g. array.array)
if m.itemsize > 1:
m = memoryview(bytes(m))
n = len(m)
m = m.cast('B')
n = m.nbytes
if offset < 0:
raise ValueError("offset is negative")
if n < offset:
Expand Down Expand Up @@ -277,12 +272,22 @@ class PipeConnection(_ConnectionBase):
with FILE_FLAG_OVERLAPPED.
"""
_got_empty_message = False
_send_ov = None

def _close(self, _CloseHandle=_winapi.CloseHandle):
ov = self._send_ov
if ov is not None:
# Interrupt WaitForMultipleObjects() in _send_bytes()
ov.cancel()
_CloseHandle(self._handle)

def _send_bytes(self, buf):
if self._send_ov is not None:
# A connection should only be used by a single thread
raise ValueError("concurrent send_bytes() calls "
"are not supported")
ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True)
self._send_ov = ov
try:
if err == _winapi.ERROR_IO_PENDING:
waitres = _winapi.WaitForMultipleObjects(
Expand All @@ -292,7 +297,13 @@ def _send_bytes(self, buf):
ov.cancel()
raise
finally:
self._send_ov = None
nwritten, err = ov.GetOverlappedResult(True)
if err == _winapi.ERROR_OPERATION_ABORTED:
# close() was called by another thread while
# WaitForMultipleObjects() was waiting for the overlapped
# operation.
raise OSError(errno.EPIPE, "handle is closed")
assert err == 0
assert nwritten == len(buf)

Expand Down Expand Up @@ -465,8 +476,9 @@ def accept(self):
'''
if self._listener is None:
raise OSError('listener is closed')

c = self._listener.accept()
if self._authkey:
if self._authkey is not None:
deliver_challenge(c, self._authkey)
answer_challenge(c, self._authkey)
return c
Expand Down Expand Up @@ -728,39 +740,227 @@ def PipeClient(address):
# Authentication stuff
#

MESSAGE_LENGTH = 20
MESSAGE_LENGTH = 40 # MUST be > 20

CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'
_CHALLENGE = b'#CHALLENGE#'
_WELCOME = b'#WELCOME#'
_FAILURE = b'#FAILURE#'

def deliver_challenge(connection, authkey):
# multiprocessing.connection Authentication Handshake Protocol Description
# (as documented for reference after reading the existing code)
# =============================================================================
#
# On Windows: native pipes with "overlapped IO" are used to send the bytes,
# instead of the length prefix SIZE scheme described below. (ie: the OS deals
# with message sizes for us)
#
# Protocol error behaviors:
#
# On POSIX, any failure to receive the length prefix into SIZE, for SIZE greater
# than the requested maxsize to receive, or receiving fewer than SIZE bytes
# results in the connection being closed and auth to fail.
#
# On Windows, receiving too few bytes is never a low level _recv_bytes read
# error, receiving too many will trigger an error only if receive maxsize
# value was larger than 128 OR if the data arrived in smaller pieces.
#
# Serving side Client side
# ------------------------------ ---------------------------------------
# 0. Open a connection on the pipe.
# 1. Accept connection.
# 2. Random 20+ bytes -> MESSAGE
# Modern servers always send
# more than 20 bytes and include
# a {digest} prefix on it with
# their preferred HMAC digest.
# Legacy ones send ==20 bytes.
# 3. send 4 byte length (net order)
# prefix followed by:
# b'#CHALLENGE#' + MESSAGE
# 4. Receive 4 bytes, parse as network byte
# order integer. If it is -1, receive an
# additional 8 bytes, parse that as network
# byte order. The result is the length of
# the data that follows -> SIZE.
# 5. Receive min(SIZE, 256) bytes -> M1
# 6. Assert that M1 starts with:
# b'#CHALLENGE#'
# 7. Strip that prefix from M1 into -> M2
# 7.1. Parse M2: if it is exactly 20 bytes in
# length this indicates a legacy server
# supporting only HMAC-MD5. Otherwise the
# 7.2. preferred digest is looked up from an
# expected "{digest}" prefix on M2. No prefix
# or unsupported digest? <- AuthenticationError
# 7.3. Put divined algorithm name in -> D_NAME
# 8. Compute HMAC-D_NAME of AUTHKEY, M2 -> C_DIGEST
# 9. Send 4 byte length prefix (net order)
# followed by C_DIGEST bytes.
# 10. Receive 4 or 4+8 byte length
# prefix (#4 dance) -> SIZE.
# 11. Receive min(SIZE, 256) -> C_D.
# 11.1. Parse C_D: legacy servers
# accept it as is, "md5" -> D_NAME
# 11.2. modern servers check the length
# of C_D, IF it is 16 bytes?
# 11.2.1. "md5" -> D_NAME
# and skip to step 12.
# 11.3. longer? expect and parse a "{digest}"
# prefix into -> D_NAME.
# Strip the prefix and store remaining
# bytes in -> C_D.
# 11.4. Don't like D_NAME? <- AuthenticationError
# 12. Compute HMAC-D_NAME of AUTHKEY,
# MESSAGE into -> M_DIGEST.
# 13. Compare M_DIGEST == C_D:
# 14a: Match? Send length prefix &
# b'#WELCOME#'
# <- RETURN
# 14b: Mismatch? Send len prefix &
# b'#FAILURE#'
# <- CLOSE & AuthenticationError
# 15. Receive 4 or 4+8 byte length prefix (net
# order) again as in #4 into -> SIZE.
# 16. Receive min(SIZE, 256) bytes -> M3.
# 17. Compare M3 == b'#WELCOME#':
# 17a. Match? <- RETURN
# 17b. Mismatch? <- CLOSE & AuthenticationError
#
# If this RETURNed, the connection remains open: it has been authenticated.
#
# Length prefixes are used consistently. Even on the legacy protocol, this
# was good fortune and allowed us to evolve the protocol by using the length
# of the opening challenge or length of the returned digest as a signal as
# to which protocol the other end supports.

_ALLOWED_DIGESTS = frozenset(
{b'md5', b'sha256', b'sha384', b'sha3_256', b'sha3_384'})
_MAX_DIGEST_LEN = max(len(_) for _ in _ALLOWED_DIGESTS)

# Old hmac-md5 only server versions from Python <=3.11 sent a message of this
# length. It happens to not match the length of any supported digest so we can
# use a message of this length to indicate that we should work in backwards
# compatible md5-only mode without a {digest_name} prefix on our response.
_MD5ONLY_MESSAGE_LENGTH = 20
_MD5_DIGEST_LEN = 16
_LEGACY_LENGTHS = (_MD5ONLY_MESSAGE_LENGTH, _MD5_DIGEST_LEN)


def _get_digest_name_and_payload(message: bytes) -> (str, bytes):
    """Split *message* into a digest name and the bytes that follow it.

    Two shapes are accepted:

    * legacy -- the length of *message* is one of ``_LEGACY_LENGTHS``.
      The message is returned untouched together with an empty digest
      name, which callers treat as "HMAC-MD5, no prefix".
    * modern -- ``b"{digest}payload"`` where ``digest`` is a member of
      ``_ALLOWED_DIGESTS``.

    Anything else raises AuthenticationError.
    """
    # Legacy challenge or legacy reply: nothing to parse, the whole
    # message is the payload.
    if len(message) in _LEGACY_LENGTHS:
        return '', message

    if message.startswith(b'{'):
        # Only look for the closing brace as far as the longest allowed
        # digest name could reach.
        closing = message.find(b'}', 1, _MAX_DIGEST_LEN+2)
        if closing > 0:
            name = message[1:closing]
            if name in _ALLOWED_DIGESTS:
                return name.decode('ascii'), message[closing+1:]

    raise AuthenticationError(
        'unsupported message length, missing digest prefix, '
        f'or unsupported digest: {message=}')


def _create_response(authkey, message):
    """Build the MAC reply for a challenge *message* keyed by *authkey*.

    The digest is taken from the '{digest_name}' prefix of *message*.
    When the message is in the legacy (unprefixed) form, a raw HMAC-MD5
    is returned; if MD5 cannot be used, HMAC-SHA256 in the modern form is
    produced instead.  Modern replies look like b'{sha256}abcdefg...'.

    Note: the MAC always covers the entire *message*, prefix included.
    """
    import hmac
    algorithm, _ = _get_digest_name_and_payload(message)
    if not algorithm:
        # Legacy peer: it expects a bare HMAC-MD5 with no prefix.
        try:
            legacy_mac = hmac.new(authkey, message, 'md5').digest()
        except ValueError:
            # HMAC-MD5 is not available (FIPS mode?), so answer with the
            # modern HMAC-SHA2-256 form.  The legacy peer probably
            # doesn't support it and will reject us anyway.
            algorithm = 'sha256'
        else:
            return legacy_mac
    # Modern protocol: announce the digest used in front of the MAC.
    mac = hmac.new(authkey, message, algorithm).digest()
    return b'{' + algorithm.encode('ascii') + b'}' + mac


def _verify_challenge(authkey, message, response):
    """Check that *response* is a valid MAC of *message* under *authkey*.

    The digest is whatever *response* announces in its '{digest_name}'
    prefix; an unprefixed (legacy) response is checked as HMAC-MD5.  The
    MAC is computed over the whole *message*, so a prefixed challenge
    binds the digest name into the value being authenticated.

    Returns None on success and raises AuthenticationError when the
    digest is unusable, the MAC has the wrong length, or it does not
    match.
    """
    import hmac
    response_digest, claimed_mac = _get_digest_name_and_payload(response)
    if not response_digest:
        # Empty name means a legacy reply.
        response_digest = 'md5'
    try:
        wanted_mac = hmac.new(authkey, message, response_digest).digest()
    except ValueError:
        raise AuthenticationError(f'{response_digest=} unsupported')

    wanted_len = len(wanted_mac)
    claimed_len = len(claimed_mac)
    if wanted_len != claimed_len:
        raise AuthenticationError(
            f'expected {response_digest!r} of length {wanted_len} '
            f'got {claimed_len}')
    if hmac.compare_digest(wanted_mac, claimed_mac):
        return
    raise AuthenticationError('digest received was wrong')


def deliver_challenge(connection, authkey: bytes, digest_name='sha256'):
if not isinstance(authkey, bytes):
raise ValueError(
"Authkey must be bytes, not {0!s}".format(type(authkey)))
assert MESSAGE_LENGTH > _MD5ONLY_MESSAGE_LENGTH, "protocol constraint"
message = os.urandom(MESSAGE_LENGTH)
connection.send_bytes(CHALLENGE + message)
digest = hmac.new(authkey, message, 'md5').digest()
message = b'{%s}%s' % (digest_name.encode('ascii'), message)
# Even when sending a challenge to a legacy client that does not support
# digest prefixes, they'll take the entire thing as a challenge and
# respond to it with a raw HMAC-MD5.
connection.send_bytes(_CHALLENGE + message)
response = connection.recv_bytes(256) # reject large message
if response == digest:
connection.send_bytes(WELCOME)
try:
_verify_challenge(authkey, message, response)
except AuthenticationError:
connection.send_bytes(_FAILURE)
raise
else:
connection.send_bytes(FAILURE)
raise AuthenticationError('digest received was wrong')
connection.send_bytes(_WELCOME)

def answer_challenge(connection, authkey):
import hmac

def answer_challenge(connection, authkey: bytes):
    """Answer the authentication challenge received on *connection*.

    Reads the challenge, replies with a MAC computed from *authkey*, then
    waits for the peer's verdict.

    Raises:
        ValueError: if *authkey* is not a bytes object.
        AuthenticationError: if the first message does not start with the
            challenge marker, if the challenge is shorter than
            _MD5ONLY_MESSAGE_LENGTH, or if the peer's final message is
            not the welcome marker.
    """
    if not isinstance(authkey, bytes):
        raise ValueError(
            "Authkey must be bytes, not {0!s}".format(type(authkey)))
    message = connection.recv_bytes(256)         # reject large message
    if not message.startswith(_CHALLENGE):
        raise AuthenticationError(
            f'Protocol error, expected challenge: {message=}')
    message = message[len(_CHALLENGE):]
    if len(message) < _MD5ONLY_MESSAGE_LENGTH:
        # The f prefix is required: without it the placeholder text is
        # sent verbatim instead of the actual length.
        raise AuthenticationError(
            f'challenge too short: {len(message)} bytes')
    digest = _create_response(authkey, message)
    connection.send_bytes(digest)
    response = connection.recv_bytes(256)        # reject large message
    if response != _WELCOME:
        raise AuthenticationError('digest sent was rejected')

#
Expand Down Expand Up @@ -943,7 +1143,7 @@ def wait(object_list, timeout=None):
return ready

#
# Make connection and socket objects sharable if possible
# Make connection and socket objects shareable if possible
#

if sys.platform == 'win32':
Expand Down
15 changes: 15 additions & 0 deletions Lib/multiprocessing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ class Process(process.BaseProcess):
def _Popen(process_obj):
return _default_context.get_context().Process._Popen(process_obj)

@staticmethod
def _after_fork():
return _default_context.get_context().Process._after_fork()

class DefaultContext(BaseContext):
Process = Process

Expand Down Expand Up @@ -254,6 +258,7 @@ def get_start_method(self, allow_none=False):
return self._actual_context._name

def get_all_start_methods(self):
"""Returns a list of the supported start methods, default first."""
if sys.platform == 'win32':
return ['spawn']
else:
Expand Down Expand Up @@ -283,6 +288,11 @@ def _Popen(process_obj):
from .popen_spawn_posix import Popen
return Popen(process_obj)

@staticmethod
def _after_fork():
# process is spawned, nothing to do
pass

class ForkServerProcess(process.BaseProcess):
_start_method = 'forkserver'
@staticmethod
Expand Down Expand Up @@ -326,6 +336,11 @@ def _Popen(process_obj):
from .popen_spawn_win32 import Popen
return Popen(process_obj)

@staticmethod
def _after_fork():
# process is spawned, nothing to do
pass

class SpawnContext(BaseContext):
_name = 'spawn'
Process = SpawnProcess
Expand Down
2 changes: 1 addition & 1 deletion Lib/multiprocessing/forkserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _stop_unlocked(self):

def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.'''
if not all(type(mod) is str for mod in self._preload_modules):
if not all(type(mod) is str for mod in modules_names):
raise TypeError('module_names must be a list of strings')
self._preload_modules = modules_names

Expand Down
Loading
Loading
0