gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) · python/cpython@c64a214 · GitHub
[go: up one dir, main page]

Skip to content

Commit c64a214

Browse files
authored
gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)
Refactor shared code in train_dict and finalize_dict
1 parent 0a68068 commit c64a214

File tree

1 file changed

+55
-68
lines changed

1 file changed

+55
-68
lines changed

Modules/_zstd/_zstdmodule.c

Lines changed: 55 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
172172
return (_zstd_state *)state;
173173
}
174174

175+
static Py_ssize_t
176+
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
177+
size_t **chunk_sizes)
178+
{
179+
Py_ssize_t chunks_number;
180+
Py_ssize_t sizes_sum;
181+
Py_ssize_t i;
182+
183+
chunks_number = Py_SIZE(samples_sizes);
184+
if ((size_t) chunks_number > UINT32_MAX) {
185+
PyErr_Format(PyExc_ValueError,
186+
"The number of samples should be <= %u.", UINT32_MAX);
187+
return -1;
188+
}
189+
190+
/* Prepare chunk_sizes */
191+
*chunk_sizes = PyMem_New(size_t, chunks_number);
192+
if (*chunk_sizes == NULL) {
193+
PyErr_NoMemory();
194+
return -1;
195+
}
196+
197+
sizes_sum = 0;
198+
for (i = 0; i < chunks_number; i++) {
199+
PyObject *size = PyTuple_GetItem(samples_sizes, i);
200+
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
201+
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
202+
PyErr_Format(PyExc_ValueError,
203+
"Items in samples_sizes should be an int "
204+
"object, with a value between 0 and %u.", SIZE_MAX);
205+
return -1;
206+
}
207+
sizes_sum += (*chunk_sizes)[i];
208+
}
209+
210+
if (sizes_sum != Py_SIZE(samples_bytes)) {
211+
PyErr_SetString(PyExc_ValueError,
212+
"The samples size tuple doesn't match the concatenation's size.");
213+
return -1;
214+
}
215+
return chunks_number;
216+
}
217+
175218

176219
/*[clinic input]
177220
_zstd.train_dict
@@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
192235
PyObject *samples_sizes, Py_ssize_t dict_size)
193236
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
194237
{
195-
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
196-
// are pretty similar. We should see if we can refactor them to share that code.
197-
Py_ssize_t chunks_number;
198-
size_t *chunk_sizes = NULL;
199238
PyObject *dst_dict_bytes = NULL;
239+
size_t *chunk_sizes = NULL;
240+
Py_ssize_t chunks_number;
200241
size_t zstd_ret;
201-
Py_ssize_t sizes_sum;
202-
Py_ssize_t i;
203242

204243
/* Check arguments */
205244
if (dict_size <= 0) {
206245
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
207246
return NULL;
208247
}
209248

210-
chunks_number = Py_SIZE(samples_sizes);
211-
if ((size_t) chunks_number > UINT32_MAX) {
212-
PyErr_Format(PyExc_ValueError,
213-
"The number of samples should be <= %u.", UINT32_MAX);
249+
/* Check that the samples are valid and get their sizes */
250+
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
251+
&chunk_sizes);
252+
if (chunks_number < 0)
253+
{
214254
return NULL;
215255
}
216256

217-
/* Prepare chunk_sizes */
218-
chunk_sizes = PyMem_New(size_t, chunks_number);
219-
if (chunk_sizes == NULL) {
220-
PyErr_NoMemory();
221-
goto error;
222-
}
223-
224-
sizes_sum = 0;
225-
for (i = 0; i < chunks_number; i++) {
226-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
227-
chunk_sizes[i] = PyLong_AsSize_t(size);
228-
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
229-
PyErr_Format(PyExc_ValueError,
230-
"Items in samples_sizes should be an int "
231-
"object, with a value between 0 and %u.", SIZE_MAX);
232-
goto error;
233-
}
234-
sizes_sum += chunk_sizes[i];
235-
}
236-
237-
if (sizes_sum != Py_SIZE(samples_bytes)) {
238-
PyErr_SetString(PyExc_ValueError,
239-
"The samples size tuple doesn't match the concatenation's size.");
240-
goto error;
241-
}
242-
243257
/* Allocate dict buffer */
244258
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
245259
if (dst_dict_bytes == NULL) {
@@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
307321
PyObject *dst_dict_bytes = NULL;
308322
size_t zstd_ret;
309323
ZDICT_params_t params;
310-
Py_ssize_t sizes_sum;
311-
Py_ssize_t i;
312324

313325
/* Check arguments */
314326
if (dict_size <= 0) {
315327
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
316328
return NULL;
317329
}
318330

319-
chunks_number = Py_SIZE(samples_sizes);
320-
if ((size_t) chunks_number > UINT32_MAX) {
321-
PyErr_Format(PyExc_ValueError,
322-
"The number of samples should be <= %u.", UINT32_MAX);
331+
/* Check that the samples are valid and get their sizes */
332+
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
333+
&chunk_sizes);
334+
if (chunks_number < 0)
335+
{
323336
return NULL;
324337
}
325338

326-
/* Prepare chunk_sizes */
327-
chunk_sizes = PyMem_New(size_t, chunks_number);
328-
if (chunk_sizes == NULL) {
329-
PyErr_NoMemory();
330-
goto error;
331-
}
332-
333-
sizes_sum = 0;
334-
for (i = 0; i < chunks_number; i++) {
335-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
336-
chunk_sizes[i] = PyLong_AsSize_t(size);
337-
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
338-
PyErr_Format(PyExc_ValueError,
339-
"Items in samples_sizes should be an int "
340-
"object, with a value between 0 and %u.", SIZE_MAX);
341-
goto error;
342-
}
343-
sizes_sum += chunk_sizes[i];
344-
}
345-
346-
if (sizes_sum != Py_SIZE(samples_bytes)) {
347-
PyErr_SetString(PyExc_ValueError,
348-
"The samples size tuple doesn't match the concatenation's size.");
349-
goto error;
350-
}
351-
352339
/* Allocate dict buffer */
353340
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
354341
if (dst_dict_bytes == NULL) {

0 commit comments

Comments
 (0)
0