10000 gh-132983: Minor fixes and clean up for the _zstd module by serhiy-storchaka · Pull Request #134930 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-132983: Minor fixes and clean up for the _zstd module #134930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 1, 2025
52 changes: 50 additions & 2 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,18 +1239,37 @@ def test_train_dict_c(self):
# argument wrong type
with self.assertRaises(TypeError):
_zstd.train_dict({}, (), 100)
with self.assertRaises(TypeError):
_zstd.train_dict(bytearray(), (), 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', 99, 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', [], 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', (), 100.1)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', (99.1,), 100)
with self.assertRaises(ValueError): 10000
_zstd.train_dict(b'abc', (4, -1), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'abc', (2,), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (99,), 100)

# size > size_t
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (2**64+1,), 100)
_zstd.train_dict(b'', (2**1000,), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (-2**1000,), 100)

# dict_size <= 0
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (), 0)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (), -1)

with self.assertRaises(ZstdError):
_zstd.train_dict(b'', (), 1)

def test_finalize_dict_c(self):
with self.assertRaises(TypeError):
Expand All @@ -1259,22 +1278,51 @@ def test_finalize_dict_c(self):
# argument wrong type
with self.assertRaises(TypeError):
_zstd.finalize_dict({}, b'', (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)

with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)

# size > size_t
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)

# dict_size <= 0
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)

with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)

with self.assertRaises(ZstdError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)

def test_train_buffer_protocol_samples(self):
def _nbytes(dat):
Expand Down
45 changes: 22 additions & 23 deletions Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ set_zstd_error(const _zstd_state* const state,
char *msg;
assert(ZSTD_isError(zstd_ret));

if (state == NULL) {
return;
}
switch (type) {
case ERR_DECOMPRESS:
msg = "Unable to decompress Zstandard data: %s";
Expand Down Expand Up @@ -174,7 +177,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
Py_ssize_t sizes_sum;
Py_ssize_t i;

chunks_number = Py_SIZE(samples_sizes);
chunks_number = PyTuple_GET_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
Expand All @@ -188,20 +191,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
return -1;
}

sizes_sum = 0;
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
(*chunk_sizes)[i] = size;
if (size == (size_t)-1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
goto sum_error;
}
return -1;
}
sizes_sum += (*chunk_sizes)[i];
if ((size_t)sizes_sum < size) {
goto sum_error;
}
sizes_sum -= size;
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
if (sizes_sum != 0) {
sum_error:
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the "
"concatenation's size.");
Expand Down Expand Up @@ -257,7 +264,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,

/* Train the dictionary */
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
Py_BEGIN_ALLOW_THREADS
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
samples_buffer,
Expand Down Expand Up @@ -507,17 +514,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
{
_zstd_state* mod_state = get_zstd_state(module);

if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
PyErr_SetString(PyExc_ValueError,
"The two arguments should be CompressionParameter and "
"DecompressionParameter types.");
return NULL;
}
Comment on lines -510 to -515
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check on removing the type checks here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type checks are already handled by argument clinic, so these are actually redundant.


Py_XSETREF(
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
Py_XSETREF(
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
Py_INCREF(c_parameter_type);
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
Py_INCREF(d_parameter_type);
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);

Py_RETURN_NONE;
}
Expand Down Expand Up @@ -580,7 +580,6 @@ do { \
return -1;
}
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
Py_DECREF(mod_state->ZstdError);
return -1;
}

Expand Down
43 changes: 16 additions & 27 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ _zstd_set_c_level(ZstdCompressor *self, int level)
/* Check error */
if (ZSTD_isError(zstd_ret)) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}
set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
return -1;
}
Expand Down Expand Up @@ -203,16 +200,16 @@ _get_CDict(ZstdDict *self, int compressionLevel)
goto error;
}

/* Add PyCapsule object to self->c_dicts */
ret = PyDict_SetItem(self->c_dicts, level, capsule);
/* Add PyCapsule object to self->c_dicts if it is not already present. */
PyObject *result;
ret = PyDict_SetDefaultRef(self->c_dicts, level, capsule, &result);
if (ret < 0) {
goto error;
}
Py_DECREF(capsule);
capsule = result;
}
else {
/* ZSTD_CDict instance already exists */
cdict = PyCapsule_GetPointer(capsule, NULL);
}
cdict = PyCapsule_GetPointer(capsule, NULL);
goto success;

error:
Expand Down Expand Up @@ -272,11 +269,7 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
int type, ret;

/* Check ZstdDict */
ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
if (ret < 0) {
return -1;
}
else if (ret > 0) {
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
/* When compressing, use undigested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_UNDIGESTED;
Expand All @@ -289,14 +282,14 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
/* Check ZstdDict */
ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
(PyObject*)mod_state->ZstdDict_type);
if (ret < 0) {
return -1;
}
else if (ret > 0) {
/* type == -1 may indicate an error. */
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
mod_state->ZstdDict_type) &&
PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return -1;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
Expand Down Expand Up @@ -481,9 +474,7 @@ compress_lock_held(ZstdCompressor *self, Py_buffer *data,
/* Check error */
if (ZSTD_isError(zstd_ret)) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
}
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
goto error;
}

Expand Down Expand Up @@ -553,9 +544,7 @@ compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
/* Check error */
if (ZSTD_isError(zstd_ret)) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
}
set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
goto error;
}

Expand Down
49 changes: 21 additions & 28 deletions Modules/_zstd/decompressor.c
7593
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,23 @@ _get_DDict(ZstdDict *self)
assert(PyMutex_IsLocked(&self->lock));
ZSTD_DDict *ret;

/* Already created */
if (self->d_dict != NULL) {
return self->d_dict;
}

if (self->d_dict == NULL) {
/* Create ZSTD_DDict instance from dictionary content */
Py_BEGIN_ALLOW_THREADS
ret = ZSTD_createDDict(self->dict_buffer, self->dict_len);
Py_END_ALLOW_THREADS
self->d_dict = ret;

if (self->d_dict == NULL) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Failed to create a ZSTD_DDict instance from "
"Zstandard dictionary content.");
if (self->d_dict != NULL) {
ZSTD_freeDDict(ret);
}
else {
self->d_dict = ret;
if (self->d_dict == NULL) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Failed to create a ZSTD_DDict instance from "
"Zstandard dictionary content.");
}
}
}
}
Expand Down Expand Up @@ -189,11 +188,7 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
int type, ret;

/* Check ZstdDict */
ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
if (ret < 0) {
return -1;
}
else if (ret > 0) {
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
/* When decompressing, use digested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_DIGESTED;
Expand All @@ -206,14 +201,14 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
/* Check ZstdDict */
ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
(PyObject*)mod_state->ZstdDict_type);
if (ret < 0) {
return -1;
}
else if (ret > 0) {
/* type == -1 may indicate an error. */
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
mod_state->ZstdDict_type) &&
PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return -1;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
Expand Down Expand Up @@ -282,9 +277,7 @@ decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
/* Check error */
if (ZSTD_isError(zstd_ret)) {
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
}
set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
goto error;
}

Expand Down
Loading
0