8000 gh-133885: Disallow sharing zstd (de)compressor contexts by emmatyping · Pull Request #134253 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-133885: Disallow sharing zstd (de)compressor contexts #134253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 11 additions & 40 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2430,83 +2430,54 @@ def test_buffer_protocol(self):
self.assertEqual(f.write(arr), LENGTH)
self.assertEqual(f.tell(), LENGTH)

@unittest.skip("it fails for now, see gh-133885")

class FreeThreadingMethodTests(unittest.TestCase):

@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_compress_locking(self):
def test_compressor_cannot_share(self):
input = b'a'* (16*_1K)
num_threads = 8

comp = ZstdCompressor()
parts = []
for _ in range(num_threads):
res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
if res:
parts.append(res)
rest1 = comp.flush()
expected = b''.join(parts) + rest1

comp = ZstdCompressor()
output = []
def run_method(method, input_data, output_data):
res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
if res:
output_data.append(res)
def run_method(method, input_data):
with self.assertRaises(RuntimeError):
method(input_data, ZstdCompressor.FLUSH_BLOCK)
threads = []

for i in range(num_threads):
thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
thread = threading.Thread(target=run_method, args=(comp.compress, input))

threads.append(thread)

with threading_helper.start_threads(threads):
pass

rest2 = comp.flush()
self.assertEqual(rest1, rest2)
actual = b''.join(output) + rest2
self.assertEqual(expected, actual)

@unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_decompress_locking(self):
def test_decompressor_cannot_share(self):
input = compress(b'a'* (16*_1K))
num_threads = 8
# to ensure we decompress over multiple calls, set maxsize
window_size = _1K * 16//num_threads

decomp = ZstdDecompressor()
parts = []
for _ in range(num_threads):
res = decomp.decompress(input, window_size)
if res:
parts.append(res)
expected = b''.join(parts)

comp = ZstdDecompressor()
output = []
def run_method(method, input_data, output_data):
res = method(input_data, window_size)
if res:
output_data.append(res)
def run_method(method, input_data):
with self.assertRaises(RuntimeError):
method(input_data, window_size)
threads = []

for i in range(num_threads):
thread = threading.Thread(target=run_method, args=(comp.decompress, input, output))
thread = threading.Thread(target=run_method, args=(comp.decompress, input))

threads.append(thread)

with threading_helper.start_threads(threads):
pass

actual = b''.join(output)
self.assertEqual(expected, actual)



if __name__ == "__main__":
unittest.main()
17 changes: 17 additions & 0 deletions Modules/_zstd/_zstdmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,21 @@ extern void
set_parameter_error(const _zstd_state* const state, int is_compress,
int key_v, int value_v);

static inline int
check_object_shared(PyObject *ob, char *type)
8000 Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably not be the right way to implement this check? Access from different threads could happen regardless of free-threading builds.

I think the concept the extension module might want to enforce is something like getting the current thread id (we have a python c api for this IIRC) upon first "use" of the object, saving that in the object state, and checking that current thread id == that thread id for subsequent uses.

first use could be construction... but unless the zstandard C API requires that, it may be better to consider first use the first time that ZStdCompressor is actually used by a C API. This allows for the pattern of:

A compressor or decompressor is created by one thread - and added to a thread pool or work queue to be actually used and exhausted - exclusively - in another thread.

{
#if defined(Py_GIL_DISABLED)
if (!_Py_IsOwnedByCurrentThread(ob))
{
PyErr_Format(PyExc_RuntimeError,
"%s cannot be shared across multiple threads.",
type);
return 1;
}
return 0;
#else
return 0;
#endif
}

#endif // !ZSTD_MODULE_H
19 changes: 12 additions & 7 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,12 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
{
PyObject *ret;

/* Check we are on the same thread as the compressor was created */
if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0)
{
return NULL;
}

/* Check mode value */
if (mode != ZSTD_e_continue &&
mode != ZSTD_e_flush &&
Expand All @@ -587,9 +593,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
return NULL;
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);

/* Compress */
if (self->use_multithread && mode == ZSTD_e_continue) {
ret = compress_mt_continue_impl(self, data);
Expand All @@ -607,7 +610,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();

return ret;
}
Expand All @@ -632,6 +634,12 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
{
PyObject *ret;

/* Check we are on the same thread as the compressor was created */
if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0)
{
return NULL;
}

/* Check mode value */
if (mode != ZSTD_e_end && mode != ZSTD_e_flush) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -641,8 +649,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
return NULL;
}

/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);
ret = compress_impl(self, NULL, mode);

if (ret) {
Expand All @@ -654,7 +660,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
Py_END_CRITICAL_SECTION();

return ret;
}
Expand Down
15 changes: 11 additions & 4 deletions Modules/_zstd/decompressor.c
6D40
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,12 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
{
PyObject *ret;

/* Check we are on the same thread as the decompressor was created */
if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0)
{
return NULL;
}

if (!self->eof) {
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
}
Expand Down Expand Up @@ -692,11 +698,12 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
/*[clinic end generated code: output=a4302b3c940dbec6 input=6463dfdf98091caa]*/
{
PyObject *ret;
/* Thread-safe code */
Py_BEGIN_CRITICAL_SECTION(self);

/* Check we are on the same thread as the decompressor was created */
if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0)
{
return NULL;
}
ret = stream_decompress(self, data, max_length);
Py_END_CRITICAL_SECTION();
return ret;
}

Expand Down
Loading
0