[3.14] gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) by miss-islington · Pull Request #134442 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

[3.14] gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) #134442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 55 additions & 68 deletions Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
return (_zstd_state *)state;
}

/* Validate the sample-size tuple against the concatenated samples buffer and
 * convert it into a C array of sizes.
 *
 *   samples_bytes  bytes object whose length must equal the sum of all sizes.
 *   samples_sizes  tuple of ints; item i is read as the size of sample i.
 *   chunk_sizes    out parameter. On success it receives an array allocated
 *                  with PyMem_New() holding one size_t per tuple item; the
 *                  caller owns it and releases it with PyMem_Free().
 *
 * Returns the number of samples on success.  On failure returns -1 with a
 * Python exception set.  Any array allocated by this function has already
 * been freed and *chunk_sizes reset to NULL by then, so a caller that simply
 * returns on error does not leak.  (If the sample count itself is rejected,
 * *chunk_sizes is left untouched because nothing was allocated.)
 */
static Py_ssize_t
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
                        size_t **chunk_sizes)
{
    Py_ssize_t chunks_number;
    Py_ssize_t i;
    size_t *sizes;
    size_t remaining;
    int mismatch = 0;

    chunks_number = Py_SIZE(samples_sizes);
    if ((size_t) chunks_number > UINT32_MAX) {
        PyErr_Format(PyExc_ValueError,
                     "The number of samples should be <= %u.", UINT32_MAX);
        return -1;
    }

    /* Prepare chunk_sizes */
    sizes = PyMem_New(size_t, chunks_number);
    *chunk_sizes = sizes;
    if (sizes == NULL) {
        PyErr_NoMemory();
        return -1;
    }

    /* Count down from the buffer length instead of summing upwards: the
       subtraction can never wrap, so a set of huge sizes cannot overflow
       into a total that accidentally equals the buffer length. */
    remaining = (size_t)Py_SIZE(samples_bytes);
    for (i = 0; i < chunks_number; i++) {
        PyObject *item = PyTuple_GetItem(samples_sizes, i);
        size_t size = PyLong_AsSize_t(item);

        /* (size_t)-1 is also a valid value, so the pending exception is
           what distinguishes a conversion failure. */
        if (size == (size_t)-1 && PyErr_Occurred()) {
            /* SIZE_MAX is a size_t, so it must be formatted with %zu. */
            PyErr_Format(PyExc_ValueError,
                         "Items in samples_sizes should be an int "
                         "object, with a value between 0 and %zu.", SIZE_MAX);
            goto error;
        }
        sizes[i] = size;

        /* Keep scanning after a mismatch so that an invalid item later in
           the tuple is still reported in preference to the size mismatch. */
        if (size > remaining) {
            mismatch = 1;
        }
        else {
            remaining -= size;
        }
    }

    if (mismatch || remaining != 0) {
        PyErr_SetString(PyExc_ValueError,
                        "The samples size tuple doesn't match the concatenation's size.");
        goto error;
    }
    return chunks_number;

error:
    /* Release the partially filled array here; callers bail out without
       freeing it when this function reports failure. */
    PyMem_Free(sizes);
    *chunk_sizes = NULL;
    return -1;
}


/*[clinic input]
_zstd.train_dict
Expand All @@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
PyObject *samples_sizes, Py_ssize_t dict_size)
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
{
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
// are pretty similar. We should see if we can refactor them to share that code.
Py_ssize_t chunks_number;
size_t *chunk_sizes = NULL;
PyObject *dst_dict_bytes = NULL;
size_t *chunk_sizes = NULL;
Py_ssize_t chunks_number;
size_t zstd_ret;
Py_ssize_t sizes_sum;
Py_ssize_t i;

/* Check arguments */
if (dict_size <= 0) {
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
return NULL;
}

chunks_number = Py_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
/* Check that the samples are valid and get their sizes */
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
&chunk_sizes);
if (chunks_number < 0)
{
return NULL;
}

/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}

sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}

/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {
Expand Down Expand Up @@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
PyObject *dst_dict_bytes = NULL;
size_t zstd_ret;
ZDICT_params_t params;
Py_ssize_t sizes_sum;
Py_ssize_t i;

/* Check arguments */
if (dict_size <= 0) {
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
return NULL;
}

chunks_number = Py_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
/* Check that the samples are valid and get their sizes */
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
&chunk_sizes);
if (chunks_number < 0)
{
return NULL;
}

/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}

sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}

/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {
Expand Down
Loading
0