diff --git a/Doc/license.rst b/Doc/license.rst index 90783e3e31a69d..480414bb84c4f2 100644 --- a/Doc/license.rst +++ b/Doc/license.rst @@ -1132,3 +1132,40 @@ The file is distributed under the 2-Clause BSD License:: THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +Zstandard bindings +------------------ + +Zstandard bindings in :file:`Modules/_zstd` and :file:`Lib/compression/zstd` +are based on code from the +`pyzstd library `_, copyright Ma Lin and +contributors. The pyzstd code is distributed under the 3-Clause BSD License:: + + Copyright (c) 2020-present, Ma Lin and contributors. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index e412db1de68f8b..40e639b14442a4 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -834,6 +834,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(bytes_per_sep)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(c_call)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(c_exception)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(c_parameter_type)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(c_return)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cached_datetime_module)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cached_statements)); @@ -887,6 +888,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(count)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(covariant)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cwd)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(d_parameter_type)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(data)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(database)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(day)); @@ -901,6 +903,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(deterministic)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(device)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(dict)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(dict_content)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(dictcomp)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(difference_update)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(digest)); @@ -966,6 +969,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(follow_symlinks)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(format)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(format_spec)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(frame_buffer)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(from_param)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(fromlist)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(fromtimestamp)); @@ -1024,6 +1028,8 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(intersection)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(interval)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(io)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(is_compress)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(is_raw)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(is_running)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(is_struct)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(isatty)); @@ -1147,6 +1153,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(overlapped)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(owner)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(pages)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(parameter)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(parent)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(password)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(path)); @@ -1308,6 +1315,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject 
*)&_Py_ID(write_through)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(year)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(zdict)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(zstd_dict)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_SINGLETON(strings).ascii[0]); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_SINGLETON(strings).ascii[1]); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_SINGLETON(strings).ascii[2]); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index 2a6c2065af6bb9..b6fbc219daaab1 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -325,6 +325,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(bytes_per_sep) STRUCT_FOR_ID(c_call) STRUCT_FOR_ID(c_exception) + STRUCT_FOR_ID(c_parameter_type) STRUCT_FOR_ID(c_return) STRUCT_FOR_ID(cached_datetime_module) STRUCT_FOR_ID(cached_statements) @@ -378,6 +379,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(count) STRUCT_FOR_ID(covariant) STRUCT_FOR_ID(cwd) + STRUCT_FOR_ID(d_parameter_type) STRUCT_FOR_ID(data) STRUCT_FOR_ID(database) STRUCT_FOR_ID(day) @@ -392,6 +394,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(deterministic) STRUCT_FOR_ID(device) STRUCT_FOR_ID(dict) + STRUCT_FOR_ID(dict_content) STRUCT_FOR_ID(dictcomp) STRUCT_FOR_ID(difference_update) STRUCT_FOR_ID(digest) @@ -457,6 +460,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(follow_symlinks) STRUCT_FOR_ID(format) STRUCT_FOR_ID(format_spec) + STRUCT_FOR_ID(frame_buffer) STRUCT_FOR_ID(from_param) STRUCT_FOR_ID(fromlist) STRUCT_FOR_ID(fromtimestamp) @@ -515,6 +519,8 @@ struct _Py_global_strings { STRUCT_FOR_ID(intersection) STRUCT_FOR_ID(interval) STRUCT_FOR_ID(io) + STRUCT_FOR_ID(is_compress) + STRUCT_FOR_ID(is_raw) STRUCT_FOR_ID(is_running) STRUCT_FOR_ID(is_struct) STRUCT_FOR_ID(isatty) @@ -638,6 +644,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(overlapped) STRUCT_FOR_ID(owner) STRUCT_FOR_ID(pages) + STRUCT_FOR_ID(parameter) STRUCT_FOR_ID(parent) STRUCT_FOR_ID(password) STRUCT_FOR_ID(path) @@ -799,6 +806,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(write_through) STRUCT_FOR_ID(year) STRUCT_FOR_ID(zdict) + STRUCT_FOR_ID(zstd_dict) } identifiers; struct { PyASCIIObject _ascii; diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index 2368157a4fd18b..b2e4029dafefb5 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -832,6 +832,7 @@ extern "C" { INIT_ID(bytes_per_sep), \ INIT_ID(c_call), \ INIT_ID(c_exception), \ + INIT_ID(c_parameter_type), \ INIT_ID(c_return), \ INIT_ID(cached_datetime_module), \ INIT_ID(cached_statements), \ @@ -885,6 +886,7 @@ extern "C" { INIT_ID(count), \ INIT_ID(covariant), \ INIT_ID(cwd), \ + INIT_ID(d_parameter_type), \ INIT_ID(data), \ INIT_ID(database), \ INIT_ID(day), \ @@ -899,6 +901,7 @@ extern "C" { INIT_ID(deterministic), \ INIT_ID(device), \ INIT_ID(dict), \ + INIT_ID(dict_content), \ INIT_ID(dictcomp), \ INIT_ID(difference_update), \ INIT_ID(digest), \ @@ -964,6 +967,7 @@ extern "C" { INIT_ID(follow_symlinks), \ INIT_ID(format), \ INIT_ID(format_spec), \ + INIT_ID(frame_buffer), \ INIT_ID(from_param), \ INIT_ID(fromlist), \ INIT_ID(fromtimestamp), \ @@ -1022,6 +1026,8 @@ extern "C" { INIT_ID(intersection), \ INIT_ID(interval), \ INIT_ID(io), \ + INIT_ID(is_compress), \ + INIT_ID(is_raw), \ INIT_ID(is_running), \ INIT_ID(is_struct), \ INIT_ID(isatty), \ @@ -1145,6 +1151,7 @@ extern "C" { INIT_ID(overlapped), \ 
INIT_ID(owner), \ INIT_ID(pages), \ + INIT_ID(parameter), \ INIT_ID(parent), \ INIT_ID(password), \ INIT_ID(path), \ @@ -1306,6 +1313,7 @@ extern "C" { INIT_ID(write_through), \ INIT_ID(year), \ INIT_ID(zdict), \ + INIT_ID(zstd_dict), \ } #define _Py_str_ascii_INIT { \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index 72c3346328a552..a8f3ce4d9a6970 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ b/Include/internal/pycore_unicodeobject_generated.h @@ -1088,6 +1088,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(c_parameter_type); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(c_return); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -1300,6 +1304,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(d_parameter_type); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(data); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -1356,6 +1364,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(dict_content); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(dictcomp); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -1616,6 +1628,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(frame_buffer); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(from_param); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -1848,6 +1864,14 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(is_compress); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(is_raw); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(is_running); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -2340,6 +2364,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(parameter); + _PyUnicode_InternStatic(interp, &string); + 
assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(parent); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -2984,6 +3012,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(zstd_dict); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_STR(empty); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py new file mode 100644 index 00000000000000..731da1a9598392 --- /dev/null +++ b/Lib/compression/zstd/__init__.py @@ -0,0 +1,286 @@ +"""Python bindings to Zstandard (zstd) compression library, the API style is +similar to Python's bz2/lzma/zlib modules. +""" + +__all__ = ( + # From this file + "compressionLevel_values", + "get_frame_info", + "CParameter", + "DParameter", + "Strategy", + "finalize_dict", + "train_dict", + "zstd_support_multithread", + "compress", + "decompress", + # From _zstd + "ZstdCompressor", + "ZstdDecompressor", + "ZstdDict", + "ZstdError", + "get_frame_size", + "zstd_version", + "zstd_version_info", + # From zstd.zstdfile + "open", + "ZstdFile", +) + +from collections import namedtuple +from enum import IntEnum +from functools import lru_cache + +from compression.zstd.zstdfile import ZstdFile, open +from _zstd import * + +import _zstd + + +_ZSTD_CStreamSizes = _zstd._ZSTD_CStreamSizes +_ZSTD_DStreamSizes = _zstd._ZSTD_DStreamSizes +_train_dict = _zstd._train_dict +_finalize_dict = _zstd._finalize_dict + + +# TODO(emmatyping): these should be dataclasses or some other class, not namedtuples + +# compressionLevel_values +_nt_values = namedtuple("values", ["default", "min", "max"]) +compressionLevel_values = _nt_values(*_zstd._compressionLevel_values) + + +_nt_frame_info = namedtuple("frame_info", ["decompressed_size", "dictionary_id"]) + + +def get_frame_info(frame_buffer): + """Get zstd frame information from a frame header. + + Parameter + frame_buffer: A bytes-like object. It should starts from the beginning of + a frame, and needs to include at least the frame header (6 to + 18 bytes). + + Return a two-items namedtuple: (decompressed_size, dictionary_id) + + If decompressed_size is None, decompressed size is unknown. + + dictionary_id is a 32-bit unsigned integer value. 0 means dictionary ID was + not recorded in the frame header, the frame may or may not need a dictionary + to be decoded, and the ID of such a dictionary is not specified. + + It's possible to append more items to the namedtuple in the future.""" + + ret_tuple = _zstd._get_frame_info(frame_buffer) + return _nt_frame_info(*ret_tuple) + + +def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + with memoryview(dat) as mv: + return mv.nbytes + + +def train_dict(samples, dict_size): + """Train a zstd dictionary, return a ZstdDict object. + + Parameters + samples: An iterable of samples, a sample is a bytes-like object + represents a file. + dict_size: The dictionary's maximum size, in bytes. 
+ """ + # Check argument's type + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + + # Prepare data + chunks = [] + chunk_sizes = [] + for chunk in samples: + chunks.append(chunk) + chunk_sizes.append(_nbytes(chunk)) + + chunks = b''.join(chunks) + if not chunks: + raise ValueError("The samples are empty content, can't train dictionary.") + + # samples_bytes: samples be stored concatenated in a single flat buffer. + # samples_size_list: a list of each sample's size. + # dict_size: size of the dictionary, in bytes. + dict_content = _train_dict(chunks, chunk_sizes, dict_size) + + return ZstdDict(dict_content) + + +def finalize_dict(zstd_dict, samples, dict_size, level): + """Finalize a zstd dictionary, return a ZstdDict object. + + Given a custom content as a basis for dictionary, and a set of samples, + finalize dictionary by adding headers and statistics according to the zstd + dictionary format. + + You may compose an effective dictionary content by hand, which is used as + basis dictionary, and use some samples to finalize a dictionary. The basis + dictionary can be a "raw content" dictionary, see is_raw parameter in + ZstdDict.__init__ method. + + Parameters + zstd_dict: A ZstdDict object, basis dictionary. + samples: An iterable of samples, a sample is a bytes-like object + represents a file. + dict_size: The dictionary's maximum size, in bytes. + level: The compression level expected to use in production. The + statistics for each compression level differ, so tuning the + dictionary for the compression level can help quite a bit. + """ + + # Check arguments' type + if not isinstance(zstd_dict, ZstdDict): + raise TypeError('zstd_dict argument should be a ZstdDict object.') + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + if not isinstance(level, int): + raise TypeError('level argument should be an int object.') + + # Prepare data + chunks = [] + chunk_sizes = [] + for chunk in samples: + chunks.append(chunk) + chunk_sizes.append(_nbytes(chunk)) + + chunks = b''.join(chunks) + if not chunks: + raise ValueError("The samples are empty content, can't finalize dictionary.") + + # custom_dict_bytes: existing dictionary. + # samples_bytes: samples be stored concatenated in a single flat buffer. + # samples_size_list: a list of each sample's size. + # dict_size: maximal size of the dictionary, in bytes. + # compression_level: compression level expected to use in production. + dict_content = _finalize_dict(zstd_dict.dict_content, + chunks, chunk_sizes, + dict_size, level) + + return _zstd.ZstdDict(dict_content) + +def compress(data, level=None, options=None, zstd_dict=None): + """Compress a block of data, return a bytes object of zstd compressed data. + + Refer to ZstdCompressor's docstring for a description of the + optional arguments *level*, *options*, and *zstd_dict*. + + For incremental compression, use an ZstdCompressor instead. + """ + comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) + return comp.compress(data, ZstdCompressor.FLUSH_FRAME) + +def decompress(data, zstd_dict=None, options=None): + """Decompress one or more frames of data. + + Refer to ZstdDecompressor's docstring for a description of the + optional arguments *zstd_dict*, *options*. + + For incremental decompression, use an ZstdDecompressor instead. 
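A minimal usage sketch (not part of the patch) of the one-shot helpers and the incremental classes described in the docstrings above, assuming a build where the new compression.zstd module is importable:

    from compression import zstd

    data = b"hello zstd " * 1000

    # One-shot helpers: compress() wraps ZstdCompressor, decompress() loops
    # over all frames in the input.
    blob = zstd.compress(data, level=3)
    assert zstd.decompress(blob) == data

    # Incremental compression: FLUSH_FRAME ends the current frame.
    c = zstd.ZstdCompressor(level=3)
    out = c.compress(data[:500])
    out += c.compress(data[500:])
    out += c.flush(zstd.ZstdCompressor.FLUSH_FRAME)

    # Incremental decompression: eof becomes True at the end of the frame.
    d = zstd.ZstdDecompressor()
    assert d.decompress(out) == data
    assert d.eof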
+ """ + results = [] + while True: + decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict) + try: + res = decomp.decompress(data) + except ZstdError: + if results: + break # Leftover data is not a valid LZMA/XZ stream; ignore it. + else: + raise # Error on the first iteration; bail out. + results.append(res) + if not decomp.eof: + raise ZstdError("Compressed data ended before the " + "end-of-stream marker was reached") + data = decomp.unused_data + if not data: + break + return b"".join(results) + +class _UnsupportedCParameter: + def __set_name__(self, _, name): + self.name = name + + def __get__(self, *_, **__): + msg = ("%s CParameter not available, zstd version is %s.") % ( + self.name, + zstd_version, + ) + raise NotImplementedError(msg) + + +class CParameter(IntEnum): + """Compression parameters""" + + compressionLevel = _zstd._ZSTD_c_compressionLevel + windowLog = _zstd._ZSTD_c_windowLog + hashLog = _zstd._ZSTD_c_hashLog + chainLog = _zstd._ZSTD_c_chainLog + searchLog = _zstd._ZSTD_c_searchLog + minMatch = _zstd._ZSTD_c_minMatch + targetLength = _zstd._ZSTD_c_targetLength + strategy = _zstd._ZSTD_c_strategy + + targetCBlockSize = _UnsupportedCParameter() + + enableLongDistanceMatching = _zstd._ZSTD_c_enableLongDistanceMatching + ldmHashLog = _zstd._ZSTD_c_ldmHashLog + ldmMinMatch = _zstd._ZSTD_c_ldmMinMatch + ldmBucketSizeLog = _zstd._ZSTD_c_ldmBucketSizeLog + ldmHashRateLog = _zstd._ZSTD_c_ldmHashRateLog + + contentSizeFlag = _zstd._ZSTD_c_contentSizeFlag + checksumFlag = _zstd._ZSTD_c_checksumFlag + dictIDFlag = _zstd._ZSTD_c_dictIDFlag + + nbWorkers = _zstd._ZSTD_c_nbWorkers + jobSize = _zstd._ZSTD_c_jobSize + overlapLog = _zstd._ZSTD_c_overlapLog + + @lru_cache(maxsize=None) + def bounds(self): + """Return lower and upper bounds of a compression parameter, both inclusive.""" + # 1 means compression parameter + return _zstd._get_param_bounds(1, self.value) + + +class DParameter(IntEnum): + """Decompression parameters""" + + windowLogMax = _zstd._ZSTD_d_windowLogMax + + @lru_cache(maxsize=None) + def bounds(self): + """Return lower and upper bounds of a decompression parameter, both inclusive.""" + # 0 means decompression parameter + return _zstd._get_param_bounds(0, self.value) + + +class Strategy(IntEnum): + """Compression strategies, listed from fastest to strongest. + + Note : new strategies _might_ be added in the future, only the order + (from fast to strong) is guaranteed. + """ + + fast = _zstd._ZSTD_fast + dfast = _zstd._ZSTD_dfast + greedy = _zstd._ZSTD_greedy + lazy = _zstd._ZSTD_lazy + lazy2 = _zstd._ZSTD_lazy2 + btlazy2 = _zstd._ZSTD_btlazy2 + btopt = _zstd._ZSTD_btopt + btultra = _zstd._ZSTD_btultra + btultra2 = _zstd._ZSTD_btultra2 + + +# Set CParameter/DParameter types for validity check +_zstd._set_parameter_types(CParameter, DParameter) + +zstd_support_multithread = CParameter.nbWorkers.bounds() != (0, 0) diff --git a/Lib/compression/zstd/zstdfile.py b/Lib/compression/zstd/zstdfile.py new file mode 100644 index 00000000000000..1ca60fe5677454 --- /dev/null +++ b/Lib/compression/zstd/zstdfile.py @@ -0,0 +1,378 @@ +import builtins +import io + +from os import PathLike + +from _zstd import ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, ZstdError +from compression._common import _streams + +__all__ = ("ZstdFile", "open") + +_ZSTD_DStreamOutSize = _ZSTD_DStreamSizes[1] + +_MODE_CLOSED = 0 +_MODE_READ = 1 +_MODE_WRITE = 2 + + +class ZstdFile(_streams.BaseStream): + """A file object providing transparent zstd (de)compression. 
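A sketch (not part of the patch) of how the CParameter/DParameter enums and Strategy values defined above could be used, assuming the compression.zstd module from this patch is importable:

    from compression import zstd
    from compression.zstd import CParameter, DParameter, Strategy

    # bounds() returns the inclusive (lower, upper) range for a parameter.
    low, high = CParameter.windowLog.bounds()

    options = {
        CParameter.compressionLevel: 7,
        CParameter.checksumFlag: 1,
        CParameter.strategy: Strategy.lazy2,
    }
    if zstd.zstd_support_multithread:
        # nbWorkers is only accepted on multithreaded zstd builds.
        options[CParameter.nbWorkers] = 2

    blob = zstd.compress(b"x" * 4096, options=options)

    # Decompression options use DParameter keys.
    assert zstd.decompress(blob, options={DParameter.windowLogMax: 27}) == b"x" * 4096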
+ + A ZstdFile can act as a wrapper for an existing file object, or refer + directly to a named file on disk. + + Note that ZstdFile provides a *binary* file interface - data read is + returned as bytes, and data to be written should be an object that + supports the Buffer Protocol. + """ + + _READER_CLASS = _streams.DecompressReader + + FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK + FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME + + def __init__( + self, + filename, + mode="r", + *, + level=None, + options=None, + zstd_dict=None, + ): + """Open a zstd compressed file in binary mode. + + filename can be either an actual file name (given as a str, bytes, or + PathLike object), in which case the named file is opened, or it can be + an existing file object to read from or write to. + + mode can be "r" for reading (default), "w" for (over)writing, "x" for + creating exclusively, or "a" for appending. These can equivalently be + given as "rb", "wb", "xb" and "ab" respectively. + + Parameters + level: The compression level to use, defaults to ZSTD_CLEVEL_DEFAULT. Note, + in read mode (decompression), compression level is not supported. + options: A dict object, containing advanced compression + parameters. + zstd_dict: A ZstdDict object, pre-trained dictionary for compression / + decompression. + """ + self._fp = None + self._closefp = False + self._mode = _MODE_CLOSED + + # Read or write mode + if mode in ("r", "rb"): + if not isinstance(options, (type(None), dict)): + raise TypeError( + ( + "In read mode (decompression), options argument " + "should be a dict object, that represents decompression " + "options." + ) + ) + if level: + raise TypeError("level argument should only be passed when writing.") + mode_code = _MODE_READ + elif mode in ("w", "wb", "a", "ab", "x", "xb"): + if not isinstance(level, (type(None), int)): + raise TypeError(("level argument should be an int object.")) + if not isinstance(options, (type(None), dict)): + raise TypeError(("options argument should be an dict object.")) + mode_code = _MODE_WRITE + self._compressor = ZstdCompressor( + level=level, options=options, zstd_dict=zstd_dict + ) + self._pos = 0 + else: + raise ValueError("Invalid mode: {!r}".format(mode)) + + # File object + if isinstance(filename, (str, bytes, PathLike)): + if "b" not in mode: + mode += "b" + self._fp = builtins.open(filename, mode) + self._closefp = True + elif hasattr(filename, "read") or hasattr(filename, "write"): + self._fp = filename + else: + raise TypeError(("filename must be a str, bytes, file or PathLike object")) + self._mode = mode_code + + if self._mode == _MODE_READ: + raw = self._READER_CLASS( + self._fp, + ZstdDecompressor, + trailing_error=ZstdError, + zstd_dict=zstd_dict, + options=options, + ) + self._buffer = io.BufferedReader(raw) + + def close(self): + """Flush and close the file. + + May be called more than once without error. Once the file is + closed, any other operation on it will raise a ValueError. + """ + # Nop if already closed + if self._fp is None: + return + try: + if self._mode == _MODE_READ: + if hasattr(self, "_buffer") and self._buffer: + self._buffer.close() + self._buffer = None + elif self._mode == _MODE_WRITE: + self.flush(self.FLUSH_FRAME) + self._compressor = None + finally: + self._mode = _MODE_CLOSED + try: + if self._closefp: + self._fp.close() + finally: + self._fp = None + self._closefp = False + + def write(self, data): + """Write a bytes-like object to the file. 
+ + Returns the number of uncompressed bytes written, which is + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until .flush() + or .close() is called. + """ + self._check_can_write() + if isinstance(data, (bytes, bytearray)): + length = len(data) + else: + # accept any data that supports the buffer protocol + data = memoryview(data) + length = data.nbytes + + compressed = self._compressor.compress(data) + self._fp.write(compressed) + self._pos += length + return length + + def flush(self, mode=FLUSH_BLOCK): + """Flush remaining data to the underlying stream. + + The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME. + Abuse of this method will reduce compression ratio, use it only when + necessary. + + If the program is interrupted afterwards, all data can be recovered. + To ensure saving to disk, also need to use os.fsync(fd). + + This method does nothing in reading mode. + """ + if self._mode == _MODE_READ: + return + self._check_not_closed() + if mode not in (self.FLUSH_BLOCK, self.FLUSH_FRAME): + raise ValueError("mode argument wrong value, it should be " + "ZstdCompressor.FLUSH_FRAME or " + "ZstdCompressor.FLUSH_BLOCK.") + if self._compressor.last_mode == mode: + return + # Flush zstd block/frame, and write. + data = self._compressor.flush(mode) + self._fp.write(data) + if hasattr(self._fp, "flush"): + self._fp.flush() + + def read(self, size=-1): + """Read up to size uncompressed bytes from the file. + + If size is negative or omitted, read until EOF is reached. + Returns b"" if the file is already at EOF. + """ + if size is None: + size = -1 + self._check_can_read() + return self._buffer.read(size) + + def read1(self, size=-1): + """Read up to size uncompressed bytes, while trying to avoid + making multiple reads from the underlying stream. Reads up to a + buffer's worth of data if size is negative. + + Returns b"" if the file is at EOF. + """ + self._check_can_read() + if size < 0: + # Note this should *not* be io.DEFAULT_BUFFER_SIZE. + # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing + # a full block is read. + size = _ZSTD_DStreamOutSize + return self._buffer.read1(size) + + def readinto(self, b): + """Read bytes into b. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto(b) + + def readinto1(self, b): + """Read bytes into b, while trying to avoid making multiple reads + from the underlying stream. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto1(b) + + def readline(self, size=-1): + """Read a line of uncompressed bytes from the file. + + The terminating newline (if present) is retained. If size is + non-negative, no more than size bytes will be read (in which + case the line may be incomplete). Returns b'' if already at EOF. + """ + self._check_can_read() + return self._buffer.readline(size) + + def seek(self, offset, whence=io.SEEK_SET): + """Change the file position. + + The new position is specified by offset, relative to the + position indicated by whence. Possible values for whence are: + + 0: start of stream (default): offset must not be negative + 1: current stream position + 2: end of stream; offset must not be positive + + Returns the new file position. + + Note that seeking is emulated, so depending on the arguments, + this operation may be extremely slow. 
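A short sketch (not part of the patch) of typical ZstdFile reading and writing, including the flush() mode argument described above; the in-memory BytesIO is a stand-in for any file object:

    import io
    from compression.zstd import ZstdFile

    buf = io.BytesIO()
    with ZstdFile(buf, "w", level=5) as f:
        f.write(b"line one\n")
        f.flush(ZstdFile.FLUSH_FRAME)   # finish the current frame now
        f.write(b"line two\n")          # written into a second frame at close()

    buf.seek(0)
    with ZstdFile(buf, "r") as f:
        assert f.readline() == b"line one\n"
        assert f.read() == b"line two\n"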
+ """ + self._check_can_read() + + # BufferedReader.seek() checks seekable + return self._buffer.seek(offset, whence) + + def peek(self, size=-1): + """Return buffered data without advancing the file position. + + Always returns at least one byte of data, unless at EOF. + The exact number of bytes returned is unspecified. + """ + # Relies on the undocumented fact that BufferedReader.peek() always + # returns at least one byte (except at EOF) + self._check_can_read() + return self._buffer.peek(size) + + def __next__(self): + ret = self._buffer.readline() + if ret: + return ret + raise StopIteration + + def tell(self): + """Return the current file position.""" + self._check_not_closed() + if self._mode == _MODE_READ: + return self._buffer.tell() + elif self._mode == _MODE_WRITE: + return self._pos + + def fileno(self): + """Return the file descriptor for the underlying file.""" + self._check_not_closed() + return self._fp.fileno() + + @property + def name(self): + self._check_not_closed() + return self._fp.name + + @property + def mode(self): + return 'wb' if self._mode == _MODE_WRITE else 'rb' + + @property + def closed(self): + """True if this file is closed.""" + return self._mode == _MODE_CLOSED + + def seekable(self): + """Return whether the file supports seeking.""" + return self.readable() and self._buffer.seekable() + + def readable(self): + """Return whether the file was opened for reading.""" + self._check_not_closed() + return self._mode == _MODE_READ + + def writable(self): + """Return whether the file was opened for writing.""" + self._check_not_closed() + return self._mode == _MODE_WRITE + + +# Copied from lzma module +def open( + filename, + mode="rb", + *, + level=None, + options=None, + zstd_dict=None, + encoding=None, + errors=None, + newline=None, +): + """Open a zstd compressed file in binary or text mode. + + filename can be either an actual file name (given as a str, bytes, or + PathLike object), in which case the named file is opened, or it can be an + existing file object to read from or write to. + + The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a", + "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode. + + The level, options, and zstd_dict parameters specify the settings the same + as ZstdFile. + + When using read mode (decompression), the options parameter is a dict + representing advanced decompression options. The level parameter is not + supported in this case. When using write mode (compression), only one of + level, an int representing the compression level, or options, a dict + representing advanced compression options, may be passed. In both modes, + zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary. + + For binary mode, this function is equivalent to the ZstdFile constructor: + ZstdFile(filename, mode, ...). In this case, the encoding, errors and + newline parameters must not be provided. + + For text mode, an ZstdFile object is created, and wrapped in an + io.TextIOWrapper instance with the specified encoding, error handling + behavior, and line ending(s). 
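A sketch (not part of the patch) of the module-level open() described above, in binary and text mode; "example.zst" is a placeholder path:

    from compression import zstd

    # Binary mode behaves like the ZstdFile constructor.
    with zstd.open("example.zst", "wb", level=3) as f:
        f.write(b"binary payload")

    # Text mode wraps the ZstdFile in an io.TextIOWrapper.
    with zstd.open("example.zst", "rt", encoding="utf-8") as f:
        text = f.read()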
+ """ + + if "t" in mode: + if "b" in mode: + raise ValueError("Invalid mode: %r" % (mode,)) + else: + if encoding is not None: + raise ValueError("Argument 'encoding' not supported in binary mode") + if errors is not None: + raise ValueError("Argument 'errors' not supported in binary mode") + if newline is not None: + raise ValueError("Argument 'newline' not supported in binary mode") + + zstd_mode = mode.replace("t", "") + binary_file = ZstdFile( + filename, zstd_mode, level=level, options=options, zstd_dict=zstd_dict + ) + + if "t" in mode: + return io.TextIOWrapper(binary_file, encoding, errors, newline) + else: + return binary_file diff --git a/Lib/shutil.py b/Lib/shutil.py index 510ae8c6f22d59..ca0a2ea2f7fa8a 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -32,6 +32,13 @@ except ImportError: _LZMA_SUPPORTED = False +try: + from compression import zstd + del zstd + _ZSTD_SUPPORTED = True +except ImportError: + _ZSTD_SUPPORTED = False + _WINDOWS = os.name == 'nt' posix = nt = None if os.name == 'posix': @@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0, tar_compression = 'bz2' elif _LZMA_SUPPORTED and compress == 'xz': tar_compression = 'xz' + elif _ZSTD_SUPPORTED and compress == 'zst': + tar_compression = 'zst' else: raise ValueError("bad value for 'compress', or compression format not " "supported : {0}".format(compress)) @@ -1134,6 +1143,10 @@ def _make_zipfile(base_name, base_dir, verbose=0, dry_run=0, _ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')], + "zstd'ed tar-file") + def get_archive_formats(): """Returns a list of supported formats for archiving and unarchiving. @@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0, 'base_name' is the name of the file to create, minus any format-specific extension; 'format' is the archive format: one of "zip", "tar", "gztar", - "bztar", or "xztar". Or any other registered format. + "bztar", "zstdtar", or "xztar". Or any other registered format. 'root_dir' is a directory that will be the root directory of the archive; ie. 
we typically chdir into 'root_dir' before creating the @@ -1359,6 +1372,10 @@ def _unpack_tarfile(filename, extract_dir, *, filter=None): _UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [], + "zstd'ed tar-file") + def _find_unpack_format(filename): for name, info in _UNPACK_FORMATS.items(): for extension in info[0]: diff --git a/Lib/tarfile.py b/Lib/tarfile.py index 82c5f6704cbd24..967c245c9c074d 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -399,7 +399,17 @@ def __init__(self, name, mode, comptype, fileobj, bufsize, self.exception = lzma.LZMAError else: self.cmp = lzma.LZMACompressor(preset=preset) - + elif comptype == "zst": + try: + from compression import zstd + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = zstd.ZstdDecompressor() + self.exception = zstd.ZstdError + else: + self.cmp = zstd.ZstdCompressor() elif comptype != "tar": raise CompressionError("unknown compression type %r" % comptype) @@ -591,6 +601,8 @@ def getcomptype(self): return "bz2" elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): return "xz" + elif self.buf.startswith(b"\x28\xb5\x2f\xfd"): + return "zst" else: return "tar" @@ -1817,11 +1829,13 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): 'r:gz' open for reading with gzip compression 'r:bz2' open for reading with bzip2 compression 'r:xz' open for reading with lzma compression + 'r:zst' open for reading with zstd compression 'a' or 'a:' open for appending, creating the file if necessary 'w' or 'w:' open for writing without compression 'w:gz' open for writing with gzip compression 'w:bz2' open for writing with bzip2 compression 'w:xz' open for writing with lzma compression + 'w:zst' open for writing with zstd compression 'x' or 'x:' create a tarfile exclusively without compression, raise an exception if the file is already created @@ -1831,16 +1845,20 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): if the file is already created 'x:xz' create an lzma compressed tarfile, raise an exception if the file is already created + 'x:zst' create a zstd compressed tarfile, raise an exception + if the file is already created 'r|*' open a stream of tar blocks with transparent compression 'r|' open an uncompressed stream of tar blocks for reading 'r|gz' open a gzip compressed stream of tar blocks 'r|bz2' open a bzip2 compressed stream of tar blocks 'r|xz' open an lzma compressed stream of tar blocks + 'r|zst' open a zstd compressed stream of tar blocks 'w|' open an uncompressed stream for writing 'w|gz' open a gzip compressed stream for writing 'w|bz2' open a bzip2 compressed stream for writing 'w|xz' open an lzma compressed stream for writing + 'w|zst' open a zstd compressed stream for writing """ if not name and not fileobj: @@ -2006,12 +2024,48 @@ def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): t._extfileobj = False return t + @classmethod + def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, + zstd_dict=None, **kwargs): + """Open zstd compressed tar archive name for reading or writing. + Appending is not allowed. 
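A sketch (not part of the patch) of the tarfile and shutil integration added above, assuming a build where the compression.zstd module is available; "some_directory" and the archive names are placeholders:

    import shutil
    import tarfile

    # tarfile: 'w:zst' / 'r:zst' select the new zstopen() handler; the level
    # keyword is forwarded to the underlying ZstdFile.
    with tarfile.open("bundle.tar.zst", "w:zst", level=10) as tf:
        tf.add("some_directory", arcname="some_directory")

    with tarfile.open("bundle.tar.zst", "r:zst") as tf:
        names = tf.getnames()

    # shutil: the registered 'zstdtar' format produces and unpacks .tar.zst.
    shutil.make_archive("bundle2", "zstdtar", root_dir="some_directory")
    shutil.unpack_archive("bundle2.tar.zst", "extracted")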
+ """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from compression.zstd import ZstdFile, ZstdError + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + + fileobj = ZstdFile( + fileobj or name, + mode, + level=level, + options=options, + zstd_dict=zstd_dict + ) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (ZstdError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a zstd file") from e + raise + except: + fileobj.close() + raise + t._extfileobj = False + return t + # All *open() methods are registered here. OPEN_METH = { "tar": "taropen", # uncompressed tar "gz": "gzopen", # gzip compressed tar "bz2": "bz2open", # bzip2 compressed tar - "xz": "xzopen" # lzma compressed tar + "xz": "xzopen", # lzma compressed tar + "zst": "zstopen" # zstd compressed tar } #-------------------------------------------------------------------------- @@ -2963,6 +3017,9 @@ def main(): '.tbz': 'bz2', '.tbz2': 'bz2', '.tb2': 'bz2', + # zstd + '.zst': 'zst', + '.tzst': 'zst', } tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w' tar_files = args.create diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 82f881094982f6..7a3c759404ff13 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -33,7 +33,7 @@ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_gil_enabled", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "requires_gzip", "requires_bz2", "requires_lzma", + "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", "has_fork_support", "requires_fork", @@ -527,6 +527,13 @@ def requires_lzma(reason='requires lzma'): lzma = None return unittest.skipUnless(lzma, reason) +def requires_zstd(reason='requires zstd'): + try: + from compression import zstd + except ImportError: + zstd = None + return unittest.skipUnless(zstd, reason) + def has_no_debug_ranges(): try: import _testcapi diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index ed01163074a507..87991fbda4c7df 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -2153,6 +2153,10 @@ def test_unpack_archive_gztar(self): def test_unpack_archive_bztar(self): self.check_unpack_tarball('bztar') + @support.requires_zstd() + def test_unpack_archive_zstdtar(self): + self.check_unpack_tarball('zstdtar') + @support.requires_lzma() @unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger") def test_unpack_archive_xztar(self): diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index fcbaf854cc294f..2d9649237a9382 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -38,6 +38,10 @@ import lzma except ImportError: lzma = None +try: + from compression import zstd +except ImportError: + zstd = None def sha256sum(data): return sha256(data).hexdigest() @@ -48,6 +52,7 @@ def sha256sum(data): gzipname = os.path.join(TEMPDIR, "testtar.tar.gz") bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2") xzname = os.path.join(TEMPDIR, "testtar.tar.xz") +zstname = os.path.join(TEMPDIR, "testtar.tar.zst") tmpname = os.path.join(TEMPDIR, "tmp.tar") dotlessname = os.path.join(TEMPDIR, "testtar") @@ -90,6 +95,12 @@ class LzmaTest: open = lzma.LZMAFile if lzma else None taropen = tarfile.TarFile.xzopen +@support.requires_zstd() +class 
ZstdTest: + tarname = zstname + suffix = 'zst' + open = zstd.ZstdFile if zstd else None + taropen = tarfile.TarFile.zstopen class ReadTest(TarTest): @@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest): class LzmaUstarReadTest(LzmaTest, UstarReadTest): pass +class ZstdUstarReadTest(ZstdTest, UstarReadTest): + pass class ListTest(ReadTest, unittest.TestCase): @@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest): class LzmaListTest(LzmaTest, ListTest): pass +class ZstdListTest(ZstdTest, ListTest): + pass class CommonReadTest(ReadTest): @@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase): class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase): pass +class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase): + pass class StreamReadTest(CommonReadTest, unittest.TestCase): @@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest): class LzmaStreamReadTest(LzmaTest, StreamReadTest): pass +class ZstdStreamReadTest(ZstdTest, StreamReadTest): + pass + class TarStreamModeReadTest(StreamModeTest, unittest.TestCase): def test_stream_mode_no_cache(self): @@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest): class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest): pass +class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest): + pass + class DetectReadTest(TarTest, unittest.TestCase): def _testfunc_file(self, name, mode): try: @@ -986,6 +1009,8 @@ def test_detect_stream_bz2(self): class LzmaDetectReadTest(LzmaTest, DetectReadTest): pass +class ZstdDetectReadTest(ZstdTest, DetectReadTest): + pass class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase): """ @@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest): class LzmaWriteTest(LzmaTest, WriteTest): pass +class ZstdWriteTest(ZstdTest, WriteTest): + pass class StreamWriteTest(WriteTestBase, unittest.TestCase): @@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest): class LzmaStreamWriteTest(LzmaTest, StreamWriteTest): decompressor = lzma.LZMADecompressor if lzma else None +class ZstdStreamWriteTest(ZstdTest, StreamWriteTest): + decompressor = zstd.ZstdDecompressor if zstd else None + class _CompressedWriteTest(TarTest): # This is not actually a standalone test. # It does not inherit WriteTest because it only makes sense with gz,bz2 @@ -2042,6 +2072,14 @@ def test_create_with_preset(self): tobj.add(self.file_path) +class ZstdCreateTest(ZstdTest, CreateTest): + + # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel. + # It does not allow for level to be specified when reading. + def test_create_with_level(self): + with tarfile.open(tmpname, self.mode, level=1) as tobj: + tobj.add(self.file_path) + class CreateWithXModeTest(CreateTest): prefix = "x" @@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase): class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase): pass +class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase): + pass class LimitsTest(unittest.TestCase): @@ -2835,7 +2875,7 @@ def test_create_command_compressed(self): support.findfile('tokenize_tests-no-coding-cookie-' 'and-utf8-bom-sig-only.txt', subdir='tokenizedata')] - for filetype in (GzipTest, Bz2Test, LzmaTest): + for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest): if not filetype.open: continue try: @@ -4257,7 +4297,7 @@ def setUpModule(): data = fobj.read() # Create compressed tarfiles. 
- for c in GzipTest, Bz2Test, LzmaTest: + for c in GzipTest, Bz2Test, LzmaTest, ZstdTest: if c.open: os_helper.unlink(c.tarname) testtarnames.append(c.tarname) diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py index 7c8a82d821a020..e1ef600af7708c 100644 --- a/Lib/test/test_zipfile/test_core.py +++ b/Lib/test/test_zipfile/test_core.py @@ -23,7 +23,8 @@ from test.support import script_helper, os_helper from test.support import ( findfile, requires_zlib, requires_bz2, requires_lzma, - captured_stdout, captured_stderr, requires_subprocess, + requires_zstd, captured_stdout, captured_stderr, + requires_subprocess, is_emscripten ) from test.support.os_helper import ( TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath @@ -693,6 +694,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractTestZip64InSmallFiles: # These tests test the ZIP64 functionality without using large files, @@ -1270,6 +1275,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractWriterTests: @@ -1339,6 +1348,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase): class LzmaWriterTests(AbstractWriterTests, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdWriterTests(AbstractWriterTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class PyZipFileTests(unittest.TestCase): def assertCompiledIn(self, name, namelist): @@ -2669,6 +2681,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase): b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' b'\x00>\x00\x00\x00\x00\x00') +@requires_zstd() +class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD + zip_with_bad_crc = ( + b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00' + b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00' + b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00' + b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK' + b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00' + b'\x00\x00') class DecryptionTests(unittest.TestCase): """Check that ZIP decryption works. 
Since the library does not @@ -2896,6 +2919,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD # Provide the tell() method but not seek() class Tellable: diff --git a/Lib/test/test_zstd/__init__.py b/Lib/test/test_zstd/__init__.py new file mode 100644 index 00000000000000..4b16ecc31156a5 --- /dev/null +++ b/Lib/test/test_zstd/__init__.py @@ -0,0 +1,5 @@ +import os +from test.support import load_package_tests + +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_zstd/__main__.py b/Lib/test/test_zstd/__main__.py new file mode 100644 index 00000000000000..e25ac946edffe4 --- /dev/null +++ b/Lib/test/test_zstd/__main__.py @@ -0,0 +1,7 @@ +import unittest + +from . import load_tests # noqa: F401 + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py new file mode 100644 index 00000000000000..f7de5decf59362 --- /dev/null +++ b/Lib/test/test_zstd/test_core.py @@ -0,0 +1,2693 @@ +import array +import gc +import io +import pathlib +import pickle +import random +import builtins +import re +import os +import unittest +import tempfile +import threading + +from compression._common import _streams + +from test.support.import_helper import import_module +from test.support import threading_helper +from test.support import _1M +from test.support import Py_GIL_DISABLED + +zstd = import_module("compression.zstd") +_zstd = import_module("_zstd") +from compression.zstd import ( + zstdfile, + compress, + decompress, + ZstdCompressor, + ZstdDecompressor, + ZstdDict, + ZstdError, + zstd_version, + zstd_version_info, + compressionLevel_values, + get_frame_info, + get_frame_size, + finalize_dict, + train_dict, + CParameter, + DParameter, + Strategy, + ZstdFile, + zstd_support_multithread, +) +from compression.zstd.zstdfile import open + +_1K = 1024 +_130_1K = 130 * _1K +DICT_SIZE1 = 3*_1K + +DAT_130K_D = None +DAT_130K_C = None + +DECOMPRESSED_DAT = None +COMPRESSED_DAT = None + +DECOMPRESSED_100_PLUS_32KB = None +COMPRESSED_100_PLUS_32KB = None + +SKIPPABLE_FRAME = None + +THIS_FILE_BYTES = None +THIS_FILE_STR = None +COMPRESSED_THIS_FILE = None + +COMPRESSED_BOGUS = None + +SAMPLES = None + +TRAINED_DICT = None + +KB = 1024 +MB = 1024*1024 + +def setUpModule(): + # uncompressed size 130KB, more than a zstd block. + # with a frame epilogue, 4 bytes checksum. 
+ global DAT_130K_D + DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*1024)]) + + global DAT_130K_C + DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksumFlag:1}) + + global DECOMPRESSED_DAT + DECOMPRESSED_DAT = b'abcdefg123456' * 1000 + + global COMPRESSED_DAT + COMPRESSED_DAT = compress(DECOMPRESSED_DAT) + + global DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*1024) + + global COMPRESSED_100_PLUS_32KB + COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB) + + global SKIPPABLE_FRAME + SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \ + (32*1024).to_bytes(4, byteorder='little') + \ + b'a' * (32*1024) + + global THIS_FILE_BYTES, THIS_FILE_STR + with builtins.open(os.path.abspath(__file__), 'rb') as f: + THIS_FILE_BYTES = f.read() + THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES) + THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8') + + global COMPRESSED_THIS_FILE + COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES) + + global COMPRESSED_BOGUS + COMPRESSED_BOGUS = DECOMPRESSED_DAT + + # dict data + words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue', + b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive', + b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird'] + lst = [] + for i in range(300): + sample = [b'%s = %d' % (random.choice(words), random.randrange(100)) + for j in range(20)] + sample = b'\n'.join(sample) + + lst.append(sample) + global SAMPLES + SAMPLES = lst + assert len(SAMPLES) > 10 + + global TRAINED_DICT + TRAINED_DICT = train_dict(SAMPLES, 3*1024) + assert len(TRAINED_DICT.dict_content) <= 3*1024 + + +class FunctionsTestCase(unittest.TestCase): + + def test_version(self): + s = ".".join((str(i) for i in zstd_version_info)) + self.assertEqual(s, zstd_version) + + def test_compressionLevel_values(self): + self.assertIs(type(compressionLevel_values.default), int) + self.assertIs(type(compressionLevel_values.min), int) + self.assertIs(type(compressionLevel_values.max), int) + self.assertLess(compressionLevel_values.min, compressionLevel_values.max) + + def test_roundtrip_default(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + dat1 = compress(raw_dat) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_roundtrip_level(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + _default, minv, maxv = compressionLevel_values + + for level in range(max(-20, minv), maxv + 1): + dat1 = compress(raw_dat, level) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_get_frame_info(self): + # no dict + info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20]) + self.assertEqual(info.decompressed_size, 32 * 1024 + 100) + self.assertEqual(info.dictionary_id, 0) + + # use dict + dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT) + info = get_frame_info(dat) + self.assertEqual(info.decompressed_size, 345) + self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id) + + with self.assertRaisesRegex(ZstdError, "not less than the frame header"): + get_frame_info(b"aaaaaaaaaaaaaa") + + def test_get_frame_size(self): + size = get_frame_size(COMPRESSED_100_PLUS_32KB) + self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB)) + + with self.assertRaisesRegex(ZstdError, "not less than this complete frame"): + get_frame_size(b"aaaaaaaaaaaaaa") + + def test_decompress_2x130_1K(self): + decompressed_size = get_frame_info(DAT_130K_C).decompressed_size + self.assertEqual(decompressed_size, _130_1K) + + dat = decompress(DAT_130K_C + DAT_130K_C) + 
self.assertEqual(len(dat), 2 * _130_1K) + +class ClassShapeTestCase(unittest.TestCase): + + def test_ZstdCompressor(self): + # class attributes + ZstdCompressor.CONTINUE + ZstdCompressor.FLUSH_BLOCK + ZstdCompressor.FLUSH_FRAME + + # method & me_1Mer + ZstdCompressor() + ZstdCompressor(12, zstd_dict=TRAINED_DICT) + c = ZstdCompressor(level=2, zstd_dict=TRAINED_DICT) + + c.compress(b"123456") + c.compress(b"123456", ZstdCompressor.CONTINUE) + c.compress(data=b"123456", mode=c.CONTINUE) + + c.flush() + c.flush(ZstdCompressor.FLUSH_BLOCK) + c.flush(mode=c.FLUSH_FRAME) + + c.last_mode + + # decompressor method & me_1Mer + with self.assertRaises(AttributeError): + c.decompress(b"") + with self.assertRaises(AttributeError): + c.at_frame_edge + with self.assertRaises(AttributeError): + c.eof + with self.assertRaises(AttributeError): + c.needs_input + + # read only attribute + with self.assertRaises(AttributeError): + c.last_mode = ZstdCompressor.FLUSH_BLOCK + + # name + self.assertIn(".ZstdCompressor", str(type(c))) + + # doesn't support pickle + with self.assertRaises(TypeError): + pickle.dumps(c) + + # supports subclass + class SubClass(ZstdCompressor): + pass + + def test_Decompressor(self): + # method & me_1Mer + ZstdDecompressor() + ZstdDecompressor(TRAINED_DICT, {}) + d = ZstdDecompressor(zstd_dict=TRAINED_DICT, options={}) + + d.decompress(b"") + d.decompress(b"", 100) + d.decompress(data=b"", max_length=100) + + d.eof + d.needs_input + d.unused_data + + # ZstdCompressor attributes + with self.assertRaises(AttributeError): + d.CONTINUE + with self.assertRaises(AttributeError): + d.FLUSH_BLOCK + with self.assertRaises(AttributeError): + d.FLUSH_FRAME + with self.assertRaises(AttributeError): + d.compress(b"") + with self.assertRaises(AttributeError): + d.flush() + + # read only attributes + with self.assertRaises(AttributeError): + d.eof = True + with self.assertRaises(AttributeError): + d.needs_input = True + with self.assertRaises(AttributeError): + d.unused_data = b"" + + # name + self.assertIn(".ZstdDecompressor", str(type(d))) + + # doesn't support pickle + with self.assertRaises(TypeError): + pickle.dumps(d) + + # supports subclass + class SubClass(ZstdDecompressor): + pass + + def test_ZstdDict(self): + ZstdDict(b"12345678", True) + zd = ZstdDict(b"12345678", is_raw=True) + + self.assertEqual(type(zd.dict_content), bytes) + self.assertEqual(zd.dict_id, 0) + self.assertEqual(zd.as_digested_dict[1], 0) + self.assertEqual(zd.as_undigested_dict[1], 1) + self.assertEqual(zd.as_prefix[1], 2) + + # name + self.assertIn(".ZstdDict", str(type(zd))) + + # doesn't support pickle + with self.assertRaisesRegex(TypeError, r"cannot pickle"): + pickle.dumps(zd) + with self.assertRaisesRegex(TypeError, r"cannot pickle"): + pickle.dumps(zd.as_prefix) + + # supports subclass + class SubClass(ZstdDict): + pass + + def test_Strategy(self): + # class attributes + Strategy.fast + Strategy.dfast + Strategy.greedy + Strategy.lazy + Strategy.lazy2 + Strategy.btlazy2 + Strategy.btopt + Strategy.btultra + Strategy.btultra2 + + def test_CParameter(self): + CParameter.compressionLevel + CParameter.windowLog + CParameter.hashLog + CParameter.chainLog + CParameter.searchLog + CParameter.minMatch + CParameter.targetLength + CParameter.strategy + with self.assertRaises(NotImplementedError): + CParameter.targetCBlockSize + + CParameter.enableLongDistanceMatching + CParameter.ldmHashLog + CParameter.ldmMinMatch + CParameter.ldmBucketSizeLog + CParameter.ldmHashRateLog + + CParameter.contentSizeFlag + 
CParameter.checksumFlag + CParameter.dictIDFlag + + CParameter.nbWorkers + CParameter.jobSize + CParameter.overlapLog + + t = CParameter.windowLog.bounds() + self.assertEqual(len(t), 2) + self.assertEqual(type(t[0]), int) + self.assertEqual(type(t[1]), int) + + def test_DParameter(self): + DParameter.windowLogMax + + t = DParameter.windowLogMax.bounds() + self.assertEqual(len(t), 2) + self.assertEqual(type(t[0]), int) + self.assertEqual(type(t[1]), int) + + def test_zstderror_pickle(self): + try: + decompress(b"invalid data") + except Exception as e: + s = pickle.dumps(e) + obj = pickle.loads(s) + self.assertEqual(type(obj), ZstdError) + else: + self.assertFalse(True, "unreachable code path") + + def test_ZstdFile_extend(self): + # These classes and variables can be used to extend ZstdFile, + # so pin them down. + self.assertTrue(issubclass(ZstdFile, io.BufferedIOBase)) + self.assertIs(ZstdFile._READER_CLASS, _streams.DecompressReader) + + # mode + self.assertEqual(zstdfile._MODE_CLOSED, 0) + self.assertEqual(zstdfile._MODE_READ, 1) + self.assertEqual(zstdfile._MODE_WRITE, 2) + + +class CompressorTestCase(unittest.TestCase): + + def test_simple_compress_bad_args(self): + # ZstdCompressor + self.assertRaises(TypeError, ZstdCompressor, []) + self.assertRaises(TypeError, ZstdCompressor, level=3.14) + self.assertRaises(TypeError, ZstdCompressor, level="abc") + self.assertRaises(TypeError, ZstdCompressor, options=b"abc") + + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234") + self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4}) + + with self.assertRaises(ValueError): + ZstdCompressor(2**31) + with self.assertRaises(ValueError): + ZstdCompressor(options={2**31: 100}) + + with self.assertRaises(ZstdError): + ZstdCompressor(options={CParameter.windowLog: 100}) + with self.assertRaises(ZstdError): + ZstdCompressor(options={3333: 100}) + + # Method bad arguments + zc = ZstdCompressor() + self.assertRaises(TypeError, zc.compress) + self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar") + self.assertRaises(TypeError, zc.compress, "str") + self.assertRaises((TypeError, ValueError), zc.flush, b"foo") + self.assertRaises(TypeError, zc.flush, b"blah", 1) + + self.assertRaises(ValueError, zc.compress, b'', -1) + self.assertRaises(ValueError, zc.compress, b'', 3) + self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0 + self.assertRaises(ValueError, zc.flush, 3) + + zc.compress(b'') + zc.compress(b'', zc.CONTINUE) + zc.compress(b'', zc.FLUSH_BLOCK) + zc.compress(b'', zc.FLUSH_FRAME) + empty = zc.flush() + zc.flush(zc.FLUSH_BLOCK) + zc.flush(zc.FLUSH_FRAME) + + def test_compress_parameters(self): + d = {CParameter.compressionLevel : 10, + + CParameter.windowLog : 12, + CParameter.hashLog : 10, + CParameter.chainLog : 12, + CParameter.searchLog : 12, + CParameter.minMatch : 4, + CParameter.targetLength : 12, + CParameter.strategy : Strategy.lazy, + + CParameter.enableLongDistanceMatching : 1, + CParameter.ldmHashLog : 12, + CParameter.ldmMinMatch : 11, + CParameter.ldmBucketSizeLog : 5, + CParameter.ldmHashRateLog : 12, + + CParameter.contentSizeFlag : 1, + CParameter.checksumFlag : 1, + CParameter.dictIDFlag : 0, + + CParameter.nbWorkers : 2 if zstd_support_multithread else 0, + CParameter.jobSize : 5*_1M if zstd_support_multithread else 0, + CParameter.overlapLog : 9 if zstd_support_multithread else 0, + } + ZstdCompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + 
d1[CParameter.ldmBucketSizeLog] = 2**31 + self.assertRaises(ValueError, ZstdCompressor, d1) + + # clamp compressionLevel + compress(b'', compressionLevel_values.max+1) + compress(b'', compressionLevel_values.min-1) + + compress(b'', {CParameter.compressionLevel:compressionLevel_values.max+1}) + compress(b'', {CParameter.compressionLevel:compressionLevel_values.min-1}) + + # zstd lib doesn't support MT compression + if not zstd_support_multithread: + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.nbWorkers:4}) + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.jobSize:4}) + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.overlapLog:4}) + + # out of bounds error msg + option = {CParameter.windowLog:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting zstd compression parameter "windowLog", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + compress(b'', option) + + def test_unknown_compression_parameter(self): + KEY = 100001234 + option = {CParameter.compressionLevel: 10, + KEY: 200000000} + pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdCompressor(option) + + @unittest.skipIf(not zstd_support_multithread, + "zstd build doesn't support multi-threaded compression") + def test_zstd_multithread_compress(self): + size = 40*_1M + b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) + + options = {CParameter.compressionLevel : 4, + CParameter.nbWorkers : 2} + + # compress() + dat1 = compress(b, options=options) + dat2 = decompress(dat1) + self.assertEqual(dat2, b) + + # ZstdCompressor + c = ZstdCompressor(options=options) + dat1 = c.compress(b, c.CONTINUE) + dat2 = c.compress(b, c.FLUSH_BLOCK) + dat3 = c.compress(b, c.FLUSH_FRAME) + dat4 = decompress(dat1+dat2+dat3) + self.assertEqual(dat4, b * 3) + + # ZstdFile + with ZstdFile(io.BytesIO(), 'w', + options=options) as f: + f.write(b) + + def test_compress_flushblock(self): + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK) + self.assertEqual(c.last_mode, c.FLUSH_BLOCK) + dat2 = c.flush() + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(dat1) + + dat3 = decompress(dat1 + dat2) + + self.assertEqual(dat3, THIS_FILE_BYTES) + + def test_compress_flushframe(self): + # test compress & decompress + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat1) + self.assertEqual(nt.decompressed_size, None) # no content size + + dat2 = decompress(dat1) + + self.assertEqual(dat2, THIS_FILE_BYTES) + + # single .FLUSH_FRAME mode has content size + c = ZstdCompressor() + dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat) + self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES)) + + def test_compress_empty(self): + # output empty content frame + self.assertNotEqual(compress(b''), b'') + + c = ZstdCompressor() + self.assertNotEqual(c.compress(b'',
c.FLUSH_FRAME), b'') + +class DecompressorTestCase(unittest.TestCase): + + def test_simple_decompress_bad_args(self): + # ZstdDecompressor + self.assertRaises(TypeError, ZstdDecompressor, ()) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc') + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4}) + + self.assertRaises(TypeError, ZstdDecompressor, options=123) + self.assertRaises(TypeError, ZstdDecompressor, options='abc') + self.assertRaises(TypeError, ZstdDecompressor, options=b'abc') + + with self.assertRaises(ValueError): + ZstdDecompressor(options={2**31 : 100}) + + with self.assertRaises(ZstdError): + ZstdDecompressor(options={DParameter.windowLogMax:100}) + with self.assertRaises(ZstdError): + ZstdDecompressor(options={3333 : 100}) + + empty = compress(b'') + lzd = ZstdDecompressor() + self.assertRaises(TypeError, lzd.decompress) + self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar") + self.assertRaises(TypeError, lzd.decompress, "str") + lzd.decompress(empty) + + def test_decompress_parameters(self): + d = {DParameter.windowLogMax : 15} + ZstdDecompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + d1[DParameter.windowLogMax] = 2**31 + self.assertRaises(ValueError, ZstdDecompressor, None, d1) + + # out of bounds error msg + options = {DParameter.windowLogMax:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting zstd decompression parameter "windowLogMax", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + decompress(b'', options=options) + + def test_unknown_decompression_parameter(self): + KEY = 100001234 + options = {DParameter.windowLogMax: DParameter.windowLogMax.bounds()[1], + KEY: 200000000} + pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdDecompressor(options=options) + + def test_decompress_epilogue_flags(self): + # DAT_130K_C has a 4 bytes checksum at frame epilogue + + # full unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'') + + # full limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'', 0) + + # [:-4] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-4] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-3] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-3] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + 
self.assertFalse(d.needs_input) + + # [:-1] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-1] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + def test_decompressor_arg(self): + zd = ZstdDict(b'12345678', True) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(zstd_dict={}) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(options=zd) + + ZstdDecompressor() + ZstdDecompressor(zd, {}) + ZstdDecompressor(zstd_dict=zd, options={DParameter.windowLogMax:25}) + + def test_decompressor_1(self): + # empty + d = ZstdDecompressor() + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + + # 130_1K full + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K full, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K, without 4 bytes checksum + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + + # above, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + + # full, unused_data + TRAIL = b'89234893abcd' + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C + TRAIL, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, TRAIL) + + def test_decompressor_chunks_read_300(self): + TRAIL = b'89234893abcd' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(300) + if not dat: + break + else: + raise Exception('should not get here') + + ret = d.decompress(dat) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + def test_decompressor_chunks_read_3(self): + TRAIL = b'89234893' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(3) + if not dat: + break + else: + dat = b'' + + ret = d.decompress(dat, 1) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + + def test_decompress_empty(self): + with self.assertRaises(ZstdError): + decompress(b'') + + d = ZstdDecompressor() + self.assertEqual(d.decompress(b''), b'') + self.assertFalse(d.eof) + + def test_decompress_empty_content_frame(self): + DAT = compress(b'') + # decompress + self.assertGreaterEqual(len(DAT), 4) + self.assertEqual(decompress(DAT), b'') + + with self.assertRaises(ZstdError): + decompress(DAT[:-1]) + + # ZstdDecompressor + d = 
ZstdDecompressor() + dat = d.decompress(DAT) + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + d = ZstdDecompressor() + dat = d.decompress(DAT[:-1]) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + +class DecompressorFlagsTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + options = {CParameter.checksumFlag:1} + c = ZstdCompressor(options) + + cls.DECOMPRESSED_42 = b'a'*42 + cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME) + + cls.DECOMPRESSED_60 = b'a'*60 + cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME) + + cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60 + cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60 + + cls._130_1K = 130*1024 + + c = ZstdCompressor() + cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush() + cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush() + cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60 + + cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|' + + def test_function_decompress(self): + + self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*1024) + + # 1 frame + self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42) + + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:1]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-1]) + + # 2 frames + self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60), + self.DECOMPRESSED_42_60) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42_60[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.UNKNOWN_FRAME_42_60[:-1]) + + # 130_1K + self.assertEqual(decompress(DAT_130K_C), DAT_130K_D) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa') + + self.assertEqual( + decompress(self.FRAME_42 + b'aaaaaaaaa'), + self.DECOMPRESSED_42 + ) + + self.assertEqual( + decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa'), + self.DECOMPRESSED_42_60 + ) + + # doesn't match checksum + checksum = DAT_130K_C[-4:] + if checksum[0] == 255: + wrong_checksum = bytes([254]) + checksum[1:] + else: + wrong_checksum = bytes([checksum[0]+1]) + checksum[1:] + + dat = DAT_130K_C[:-4] + wrong_checksum + + with self.assertRaisesRegex(ZstdError, "doesn't match checksum"): + decompress(dat) + + def test_function_skippable(self): + self.assertEqual(decompress(SKIPPABLE_FRAME), b'') + self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'') + + # 1 frame + 2 skippable + self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + 
DAT_130K_C)), + self._130_1K) + + self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)), + self._130_1K) + + self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)), + self._130_1K) + + # unknown size + self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_60) + + # 2 frames + 1 skippable + self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_42_60) + + # incomplete + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:1]) + + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:-1]) + + with self.assertRaises(ZstdError): + decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME) + + self.assertEqual( + decompress(SKIPPABLE_FRAME + b'aaaaaaaaa'), + b'' + ) + + self.assertEqual( + decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa'), + b'' + ) + + def test_decompressor_1(self): + # empty 1 + d = ZstdDecompressor() + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # empty 2 + d = ZstdDecompressor() + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # 1 frame + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, trail + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42 + self.TRAIL) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # 1 frame, 32_1K + temp = compress(b'a'*(32*1024)) + d = ZstdDecompressor() + dat = d.decompress(temp, 32*1024) + + 
self.assertEqual(dat, b'a'*(32*1024)) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, 32_1K+100, trail + d = ZstdDecompressor() + dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes + + self.assertEqual(len(dat), 100) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + + dat = d.decompress(b'') # 32_1K + + self.assertEqual(len(dat), 32*1024) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # incomplete 1 + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_60[:1]) + + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 2 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 3 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-1]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + + # incomplete 4 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4], 60) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # Unknown frame descriptor + d = ZstdDecompressor() + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + d.decompress(b'aaaaaaaaa') + + def test_decompressor_skippable(self): + # 1 skippable + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, max_length=0 + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME, 0) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, trail + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1]) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1], 0) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + 
self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + + +class ZstdDictTestCase(unittest.TestCase): + + def test_is_raw(self): + # content < 8 + b = b'1234567' + with self.assertRaises(ValueError): + ZstdDict(b) + + # content == 8 + b = b'12345678' + zd = ZstdDict(b, is_raw=True) + self.assertEqual(zd.dict_id, 0) + + temp = compress(b'aaa12345678', level=3, zstd_dict=zd) + self.assertEqual(b'aaa12345678', decompress(temp, zd)) + + # is_raw == False + b = b'12345678abcd' + with self.assertRaises(ValueError): + ZstdDict(b) + + # read only attributes + with self.assertRaises(AttributeError): + zd.dict_content = b + + with self.assertRaises(AttributeError): + zd.dict_id = 10000 + + # ZstdDict arguments + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False) + self.assertNotEqual(zd.dict_id, 0) + + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True) + self.assertNotEqual(zd.dict_id, 0) # note this assertion + + with self.assertRaises(TypeError): + ZstdDict("12345678abcdef", is_raw=True) + with self.assertRaises(TypeError): + ZstdDict(TRAINED_DICT) + + # invalid parameter + with self.assertRaises(TypeError): + ZstdDict(desk333=345) + + def test_invalid_dict(self): + DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little') + dict_content = DICT_MAGIC + b'abcdefghijklmnopqrstuvwxyz' + + # corrupted + zd = ZstdDict(dict_content, is_raw=False) + with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?corrupted'): + ZstdCompressor(zstd_dict=zd.as_digested_dict) + with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?corrupted'): + ZstdDecompressor(zd) + + # wrong type + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 3)) + + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 3)) + + def test_train_dict(self): + + + TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1) + ZstdDict(TRAINED_DICT.dict_content, False) + + self.assertNotEqual(TRAINED_DICT.dict_id, 0) + self.assertGreater(len(TRAINED_DICT.dict_content), 0) + self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1) + self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT))) + + # compress/decompress + c = ZstdCompressor(zstd_dict=TRAINED_DICT) + for sample in SAMPLES: + dat1 = compress(sample, zstd_dict=TRAINED_DICT) + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + def test_finalize_dict(self): + if zstd_version_info < (1, 4, 5): + return + + DICT_SIZE2 = 200*1024 + C_LEVEL = 6 + + try: + dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL) + except
NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + self.assertNotEqual(dic2.dict_id, 0) + self.assertGreater(len(dic2.dict_content), 0) + self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2) + + # compress/decompress + c = ZstdCompressor(C_LEVEL, zstd_dict=dic2) + for sample in SAMPLES: + dat1 = compress(sample, C_LEVEL, zstd_dict=dic2) + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + # dict mismatch + self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id) + + dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT) + with self.assertRaises(ZstdError): + decompress(dat1, dic2) + + def test_train_dict_arguments(self): + with self.assertRaises(ValueError): + train_dict([], 100*_1K) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, -100) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, 0) + + def test_finalize_dict_arguments(self): + if zstd_version_info < (1, 4, 5): + with self.assertRaises(NotImplementedError): + finalize_dict({1:2}, [b'aaa', b'bbb'], 100*_1K, 2) + return + + try: + finalize_dict(TRAINED_DICT, SAMPLES, 1*_1M, 2) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, [], 100*_1K, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, -100, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, 0, 2) + + def test_train_dict_c(self): + # argument wrong type + with self.assertRaises(TypeError): + _zstd._train_dict({}, [], 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', [], 100.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._train_dict(b'', [2**64+1], 100) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._train_dict(b'', [], 0) + + def test_finalize_dict_c(self): + if zstd_version_info < (1, 4, 5): + with self.assertRaises(NotImplementedError): + _zstd._finalize_dict(1, 2, 3, 4, 5) + return + + try: + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'123', [3,], 1*_1M, 5) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + # argument wrong type + with self.assertRaises(TypeError): + _zstd._finalize_dict({}, b'', [], 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, [], 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100.1, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [2**64+1], 100, 5) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 0, 5) + + def test_train_buffer_protocol_samples(self): + def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + return memoryview(dat).nbytes + + # prepare samples + chunk_lst = [] + wrong_size_lst = [] + correct_size_lst = [] + for _ in range(300): + arr = array.array('Q', [random.randint(0, 20) for i in range(20)]) + chunk_lst.append(arr) + 
correct_size_lst.append(_nbytes(arr)) + wrong_size_lst.append(len(arr)) + concatenation = b''.join(chunk_lst) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size list doesn't match the concatenation's size"): + _zstd._train_dict(concatenation, wrong_size_lst, 100*1024) + + # correct size list + _zstd._train_dict(concatenation, correct_size_lst, 3*1024) + + # test _finalize_dict + if zstd_version_info < (1, 4, 5): + return + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size list doesn't match the concatenation's size"): + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, wrong_size_lst, 300*1024, 5) + + # correct size list + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, correct_size_lst, 300*1024, 5) + + def test_as_prefix(self): + # V1 + V1 = THIS_FILE_BYTES + zd = ZstdDict(V1, True) + + # V2 + mid = len(V1) // 2 + V2 = V1[:mid] + \ + (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \ + V1[mid+1:] + + # compress + dat = compress(V2, zstd_dict=zd.as_prefix) + self.assertEqual(get_frame_info(dat).dictionary_id, 0) + + # decompress + self.assertEqual(decompress(dat, zd.as_prefix), V2) + + # use wrong prefix + zd2 = ZstdDict(SAMPLES[0], True) + try: + decompressed = decompress(dat, zd2.as_prefix) + except ZstdError: # expected + pass + else: + self.assertNotEqual(decompressed, V2) + + # read only attribute + with self.assertRaises(AttributeError): + zd.as_prefix = b'1234' + + def test_as_digested_dict(self): + zd = TRAINED_DICT + + # test .as_digested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict) + self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_digested_dict = b'1234' + + # test .as_undigested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict) + self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_undigested_dict = b'1234' + + def test_advanced_compression_parameters(self): + options = {CParameter.compressionLevel: 6, + CParameter.windowLog: 20, + CParameter.enableLongDistanceMatching: 1} + + # automatically select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + # explicitly select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + def test_len(self): + self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content)) + self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT)) + +class FileTestCase(unittest.TestCase): + def setUp(self): + self.DECOMPRESSED_42 = b'a'*42 + self.FRAME_42 = compress(self.DECOMPRESSED_42) + + def test_init(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + pass + with ZstdFile(io.BytesIO(), "w") as f: + pass + with ZstdFile(io.BytesIO(), "x") as f: + pass + with ZstdFile(io.BytesIO(), "a") as f: + pass + + with ZstdFile(io.BytesIO(), "w", level=12) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={CParameter.checksumFlag:1}) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={}) as f: + pass + with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f: + pass + + with ZstdFile(io.BytesIO(), "r", options={DParameter.windowLogMax:25}) as f: + pass + with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f: + pass + + def 
test_init_with_PathLike_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2) + + os.remove(filename) + + def test_init_with_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename) as f: + pass + with ZstdFile(filename, "w") as f: + pass + with ZstdFile(filename, "a") as f: + pass + + os.remove(filename) + + def test_init_mode(self): + bi = io.BytesIO() + + with ZstdFile(bi, "r"): + pass + with ZstdFile(bi, "rb"): + pass + with ZstdFile(bi, "w"): + pass + with ZstdFile(bi, "wb"): + pass + with ZstdFile(bi, "a"): + pass + with ZstdFile(bi, "ab"): + pass + + def test_init_with_x_mode(self): + with tempfile.NamedTemporaryFile() as tmp_f: + filename = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb"): + with ZstdFile(filename, mode): + pass + with self.assertRaises(FileExistsError): + with ZstdFile(filename, mode): + pass + os.remove(filename) + + def test_init_bad_mode(self): + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x")) + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") + + with self.assertRaisesRegex(TypeError, r"NOT be CParameter"): + ZstdFile(io.BytesIO(), 'rb', options={CParameter.compressionLevel:5}) + with self.assertRaisesRegex(TypeError, r"NOT be DParameter"): + ZstdFile(io.BytesIO(), 'wb', options={DParameter.windowLogMax:21}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12) + + def test_init_bad_check(self): + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(), "w", level='asd') + # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid. 
+ with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={999:9999}) + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={CParameter.windowLog:99}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) + + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DParameter.windowLogMax:2**31}) + + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={444:333}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456') + + def test_init_close_fp(self): + # get a temp file name + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + tmp_f.write(DAT_130K_C) + filename = tmp_f.name + + with self.assertRaises(ValueError): + ZstdFile(filename, options={'a':'b'}) + + # for PyPy + gc.collect() + + os.remove(filename) + + def test_close(self): + with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src: + f = ZstdFile(src) + f.close() + # ZstdFile.close() should not close the underlying file object. + self.assertFalse(src.closed) + # Try closing an already-closed ZstdFile. + f.close() + self.assertFalse(src.closed) + + # Test with a real file on disk, opened directly by ZstdFile. + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + fp = f._fp + f.close() + # Here, ZstdFile.close() *should* close the underlying file object. + self.assertTrue(fp.closed) + # Try closing an already-closed ZstdFile. + f.close() + + os.remove(filename) + + def test_closed(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.closed) + f.read() + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + def test_fileno(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertRaises(io.UnsupportedOperation, f.fileno) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.fileno(), f._fp.fileno()) + self.assertIsInstance(f.fileno(), int) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + os.remove(filename) + + # 3, no .fileno() method + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'fileno'): + f.fileno() + + def test_seekable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.seekable()) + f.read() + self.assertTrue(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + src = io.BytesIO(COMPRESSED_100_PLUS_32KB) + src.seekable = lambda: False + f = ZstdFile(src) + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + def test_readable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.readable()) + f.read() + self.assertTrue(f.readable()) + finally: + f.close() 
+ self.assertRaises(ValueError, f.readable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + def test_writable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.writable()) + f.read() + self.assertFalse(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertTrue(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + def test_read_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertEqual(f.read(0), b"") + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DParameter.windowLogMax:20}) as f: + self.assertEqual(f.read(0), b"") + + # empty file + with ZstdFile(io.BytesIO(b'')) as f: + self.assertEqual(f.read(0), b"") + with self.assertRaises(EOFError): + f.read(10) + + with ZstdFile(io.BytesIO(b'')) as f: + with self.assertRaises(EOFError): + f.read(10) + + def test_read_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + chunks = [] + while True: + result = f.read(10) + if not result: + break + self.assertLessEqual(len(result), 10) + chunks.append(result) + self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB) + + def test_read_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT) + + def test_read_incomplete(self): + with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f: + self.assertRaises(EOFError, f.read) + + # Trailing data isn't a valid compressed stream + with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f: + self.assertEqual(f.read(), self.DECOMPRESSED_42) + + with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f: + self.assertEqual(f.read(), b'') + + def test_read_truncated(self): + # Drop stream epilogue: 4 bytes checksum + truncated = DAT_130K_C[:-4] + with ZstdFile(io.BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + + with ZstdFile(io.BytesIO(truncated)) as f: + # this is an important test, make sure it doesn't raise EOFError. 
+ self.assertEqual(f.read(130*1024), DAT_130K_D) + with self.assertRaises(EOFError): + f.read(1) + + # Incomplete header + for i in range(1, 20): + with ZstdFile(io.BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + def test_read_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_DAT)) + f.close() + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertRaises(TypeError, f.read, float()) + + def test_read_bad_data(self): + with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f: + self.assertRaises(ZstdError, f.read) + + def test_read_exception(self): + class C: + def read(self, size=-1): + raise OSError + with ZstdFile(C()) as f: + with self.assertRaises(OSError): + f.read(10) + + def test_read1(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DAT_130K_D) + self.assertEqual(f.read1(), b"") + + def test_read1_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertEqual(f.read1(0), b"") + + def test_read1_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + blocks = [] + while True: + result = f.read1(10) + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT) + self.assertEqual(f.read1(), b"") + + def test_read1_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5) + self.assertEqual(f.read1(), b"") + + def test_read1_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(TypeError, f.read1, None) + + def test_readinto(self): + arr = array.array("I", range(100)) + self.assertEqual(len(arr), 100) + self.assertEqual(len(arr) * arr.itemsize, 400) + ba = bytearray(300) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + # 0 length output buffer + self.assertEqual(f.readinto(ba[0:0]), 0) + + # use correct length for buffer protocol object + self.assertEqual(f.readinto(arr), 400) + self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400]) + + # normal readinto + self.assertEqual(f.readinto(ba), 300) + self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700]) + + def test_peek(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek() + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek(10) + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + + def test_peek_bad_args(self): + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.peek) + + def test_iterator(self): + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + compressed = compress(THIS_FILE_BYTES) + + # iter + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(list(iter(f)), lines) + + # readline + with ZstdFile(io.BytesIO(compressed)) as f: + for line in lines: + 
self.assertEqual(f.readline(), line) + self.assertEqual(f.readline(), b'') + self.assertEqual(f.readline(), b'') + + # readlines + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(f.readlines(), lines) + + def test_decompress_limited(self): + _ZSTD_DStreamInSize = 128*1024 + 3 + + bomb = compress(b'\0' * int(2e6), level=10) + self.assertLess(len(bomb), _ZSTD_DStreamInSize) + + decomp = ZstdFile(io.BytesIO(bomb)) + self.assertEqual(decomp.read(1), b'\0') + + # BufferedReader uses 128 KiB buffer in __init__.py + max_decomp = 128*1024 + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + + def test_write(self): + raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(raw_data) + + comp = ZstdCompressor() + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", level=12) as f: + f.write(raw_data) + + comp = ZstdCompressor(12) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", options={CParameter.checksumFlag:1}) as f: + f.write(raw_data) + + comp = ZstdCompressor({CParameter.checksumFlag:1}) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + options = {CParameter.compressionLevel:-5, + CParameter.checksumFlag:1} + with ZstdFile(dst, "w", + options=options) as f: + f.write(raw_data) + + comp = ZstdCompressor(options=options) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_empty_frame(self): + # .FLUSH_FRAME generates an empty content frame + c = ZstdCompressor() + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + + # don't generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + pass + self.assertEqual(bo.getvalue(), b'') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_FRAME) + self.assertEqual(bo.getvalue(), b'') + + # if .write(b''), generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'') + self.assertNotEqual(bo.getvalue(), b'') + + # has an empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertNotEqual(bo.getvalue(), b'') + + def test_write_empty_block(self): + # If no internal data, .FLUSH_BLOCK return b''. 
+ c = ZstdCompressor() + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK), + b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + + # mode = .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + f.flush(f.FLUSH_BLOCK) + fp_pos = f._fp.tell() + self.assertNotEqual(fp_pos, 0) + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), fp_pos) + + # mode != .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + f.write(b'') + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + + def test_write_101(self): + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + for start in range(0, len(THIS_FILE_BYTES), 101): + f.write(THIS_FILE_BYTES[start:start+101]) + + comp = ZstdCompressor() + expected = comp.compress(THIS_FILE_BYTES) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_append(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + part1 = THIS_FILE_BYTES[:1024] + part2 = THIS_FILE_BYTES[1024:1536] + part3 = THIS_FILE_BYTES[1536:] + expected = b"".join(comp(x) for x in (part1, part2, part3)) + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(part1) + with ZstdFile(dst, "a") as f: + f.write(part2) + with ZstdFile(dst, "a") as f: + f.write(part3) + self.assertEqual(dst.getvalue(), expected) + + def test_write_bad_args(self): + f = ZstdFile(io.BytesIO(), "w") + f.close() + self.assertRaises(ValueError, f.write, b"foo") + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f: + self.assertRaises(ValueError, f.write, b"bar") + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(TypeError, f.write, None) + self.assertRaises(TypeError, f.write, "text") + self.assertRaises(TypeError, f.write, 789) + + def test_writelines(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.writelines(lines) + expected = comp(THIS_FILE_BYTES) + self.assertEqual(dst.getvalue(), expected) + + def test_seek_forward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(555) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:]) + + def test_seek_forward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:]) + + def test_seek_forward_relative_to_current(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(100) + f.seek(1236, 1) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:]) + + def test_seek_forward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-555, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:]) + + def test_seek_backward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(1001) + f.seek(211) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:]) + + def test_seek_backward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333) + f.seek(737) + self.assertEqual(f.read(), + 
DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB) + + def test_seek_backward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-150, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:]) + + def test_seek_past_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001) + self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB)) + self.assertEqual(f.read(), b"") + + def test_seek_past_start(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-88) + self.assertEqual(f.tell(), 0) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + def test_seek_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(ValueError, f.seek, 0, 3) + # io.BufferedReader raises TypeError instead of ValueError + self.assertRaises((TypeError, ValueError), f.seek, 9, ()) + self.assertRaises(TypeError, f.seek, None) + self.assertRaises(TypeError, f.seek, b"derp") + + def test_seek_not_seekable(self): + class C(io.BytesIO): + def seekable(self): + return False + obj = C(COMPRESSED_100_PLUS_32KB) + with ZstdFile(obj, 'r') as f: + d = f.read(1) + self.assertFalse(f.seekable()) + with self.assertRaisesRegex(io.UnsupportedOperation, + 'File or stream is not seekable'): + f.seek(0) + d += f.read() + self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB) + + def test_tell(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + pos = 0 + while True: + self.assertEqual(f.tell(), pos) + result = f.read(random.randint(171, 189)) + if not result: + break + pos += len(result) + self.assertEqual(f.tell(), len(DAT_130K_D)) + with ZstdFile(io.BytesIO(), "w") as f: + for pos in range(0, len(DAT_130K_D), 143): + self.assertEqual(f.tell(), pos) + f.write(DAT_130K_D[pos:pos+143]) + self.assertEqual(f.tell(), len(DAT_130K_D)) + + def test_tell_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.tell) + + def test_file_dict(self): + # default + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_file_prefix(self): + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_UnsupportedOperation(self): + # 1 + with ZstdFile(io.BytesIO(), 'r') as f: + with self.assertRaises(io.UnsupportedOperation): + f.write(b'1234') + + # 2 + class T: + def read(self, size): + return b'a' * size + + with self.assertRaises(AttributeError): # on close + with ZstdFile(T(), 'w') as f: + with self.assertRaises(AttributeError): # on write + f.write(b'1234') + + # 3 + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.read(100) + with 
self.assertRaises(io.UnsupportedOperation): + f.seek(100) + self.assertEqual(f.closed, True) + with self.assertRaises(ValueError): + f.readable() + with self.assertRaises(ValueError): + f.tell() + with self.assertRaises(ValueError): + f.read(100) + + def test_read_readinto_readinto1(self): + lst = [] + with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f: + while True: + method = random.randint(0, 2) + size = random.randint(0, 300) + + if method == 0: + dat = f.read(size) + if not dat and size: + break + lst.append(dat) + elif method == 1: + ba = bytearray(size) + read_size = f.readinto(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + elif method == 2: + ba = bytearray(size) + read_size = f.readinto1(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5) + + def test_zstdfile_flush(self): + # closed + f = ZstdFile(io.BytesIO(), 'w') + f.close() + with self.assertRaises(ValueError): + f.flush() + + # read + with ZstdFile(io.BytesIO(), 'r') as f: + # does nothing for read-only stream + f.flush() + + # write + DAT = b'abcd' + bi = io.BytesIO() + with ZstdFile(bi, 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + self.assertEqual(bi.tell(), 0) # not enough for a block + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + self.assertGreater(bi.tell(), 0) # flushed + + # write, no .flush() method + class C: + def write(self, b): + return len(b) + with ZstdFile(C(), 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + + def test_zstdfile_flush_mode(self): + self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK) + self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME) + with self.assertRaises(AttributeError): + ZstdFile.CONTINUE + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + # flush block + self.assertEqual(f.write(b'123'), 3) + self.assertIsNone(f.flush(f.FLUSH_BLOCK)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush()) + p2 = bo.tell() + self.assertEqual(p1, p2) + # flush frame + self.assertEqual(f.write(b'456'), 3) + self.assertIsNone(f.flush(mode=f.FLUSH_FRAME)) + # flush frame + self.assertEqual(f.write(b'789'), 3) + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p2 = bo.tell() + self.assertEqual(p1, p2) + self.assertEqual(decompress(bo.getvalue()), b'123456789') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'): + f.flush(ZstdCompressor.CONTINUE) + with self.assertRaises(ValueError): + f.flush(-1) + with self.assertRaises(ValueError): + f.flush(123456) + with self.assertRaises(TypeError): + f.flush(node=ZstdCompressor.CONTINUE) + with self.assertRaises((TypeError, ValueError)): + f.flush('FLUSH_FRAME') + with self.assertRaises(TypeError): + f.flush(b'456', f.FLUSH_BLOCK) + + def test_zstdfile_truncate(self): + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.truncate(200) + + def test_zstdfile_iter_issue45475(self): + lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))] + self.assertGreater(len(lines), 0) + + def test_append_new_file(self): + with tempfile.NamedTemporaryFile(delete=True) as tmp_f: + 
filename = tmp_f.name + + with ZstdFile(filename, 'a') as f: + pass + self.assertTrue(os.path.isfile(filename)) + + os.remove(filename) + +class OpenTestCase(unittest.TestCase): + + def test_binary_modes(self): + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with io.BytesIO() as bio: + with open(bio, "wb") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + with open(bio, "ab") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2) + + def test_text_modes(self): + # empty input + with self.assertRaises(EOFError): + with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader: + for _ in reader: + pass + + # read + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f: + self.assertEqual(f.read(), uncompressed) + + with io.BytesIO() as bio: + # write + with open(bio, "wt", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + # append + with open(bio, "at", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2) + + def test_bad_params(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + with self.assertRaises(ValueError): + open(TESTFN, "") + with self.assertRaises(ValueError): + open(TESTFN, "rbt") + with self.assertRaises(ValueError): + open(TESTFN, "rb", encoding="utf-8") + with self.assertRaises(ValueError): + open(TESTFN, "rb", errors="ignore") + with self.assertRaises(ValueError): + open(TESTFN, "rb", newline="\n") + + os.remove(TESTFN) + + def test_option(self): + options = {DParameter.windowLogMax:25} + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + options = {CParameter.compressionLevel:12} + with io.BytesIO() as bio: + with open(bio, "wb", options=options) as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + + def test_encoding(self): + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-16-le") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-16-le") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + bio.seek(0) + with open(bio, "rt", encoding="utf-16-le") as f: + self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed) + + def test_encoding_error_handler(self): + with io.BytesIO(compress(b"foo\xffbar")) as bio: + with open(bio, "rt", encoding="ascii", errors="ignore") as f: + self.assertEqual(f.read(), "foobar") + + def test_newline(self): + # Test with explicit newline (universal newline mode disabled). 
+ text = THIS_FILE_STR.replace(os.linesep, "\n") + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-8", newline="\n") as f: + f.write(text) + bio.seek(0) + with open(bio, "rt", encoding="utf-8", newline="\r") as f: + self.assertEqual(f.readlines(), [text]) + + def test_x_mode(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb", "xt"): + os.remove(TESTFN) + + if mode == "xt": + encoding = "utf-8" + else: + encoding = None + with open(TESTFN, mode, encoding=encoding): + pass + with self.assertRaises(FileExistsError): + with open(TESTFN, mode): + pass + + os.remove(TESTFN) + + def test_open_dict(self): + # default + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # invalid dictionary + bi = io.BytesIO() + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict={1:2, 2:3}) + + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict=b'1234567890') + + def test_open_prefix(self): + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_buffer_protocol(self): + # don't use len() for buffer protocol objects + arr = array.array("i", range(1000)) + LENGTH = len(arr) * arr.itemsize + + with open(io.BytesIO(), "wb") as f: + self.assertEqual(f.write(arr), LENGTH) + self.assertEqual(f.tell(), LENGTH) + +class FreeThreadingMethodTests(unittest.TestCase): + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_locking(self): + input = b'a'* (16*_1K) + num_threads = 8 + + comp = ZstdCompressor() + parts = [] + for _ in range(num_threads): + res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) + if res: + parts.append(res) + rest1 = comp.flush() + expected = b''.join(parts) + rest1 + + comp = ZstdCompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, ZstdCompressor.FLUSH_BLOCK) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + rest2 = comp.flush() + self.assertEqual(rest1, rest2) + actual = b''.join(output) + rest2 + self.assertEqual(expected, actual) + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_locking(self): + input = compress(b'a'* (16*_1K)) + num_threads = 8 + # to ensure we decompress over multiple calls, set maxsize + window_size = _1K * 16//num_threads + + decomp = ZstdDecompressor() + parts = [] + for _ in range(num_threads): + res = decomp.decompress(input, window_size) + if res: + parts.append(res) + expected = 
b''.join(parts) + + comp = ZstdDecompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, window_size) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + actual = b''.join(output) + self.assertEqual(expected, actual) + + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py index b7840d0f945a66..f6c85c73a43d93 100644 --- a/Lib/zipfile/__init__.py +++ b/Lib/zipfile/__init__.py @@ -31,6 +31,11 @@ except ImportError: lzma = None +try: + from compression import zstd # We may need its compression method +except ImportError: + zstd = None + __all__ = ["BadZipFile", "BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA", "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile", @@ -58,12 +63,14 @@ class LargeZipFile(Exception): ZIP_DEFLATED = 8 ZIP_BZIP2 = 12 ZIP_LZMA = 14 +ZIP_ZSTANDARD = 93 # Other ZIP compression methods not supported DEFAULT_VERSION = 20 ZIP64_VERSION = 45 BZIP2_VERSION = 46 LZMA_VERSION = 63 +ZSTANDARD_VERSION = 63 # we recognize (but not necessarily support) all features up to that version MAX_EXTRACT_VERSION = 63 @@ -505,6 +512,8 @@ def FileHeader(self, zip64=None): min_version = max(BZIP2_VERSION, min_version) elif self.compress_type == ZIP_LZMA: min_version = max(LZMA_VERSION, min_version) + elif self.compress_type == ZIP_ZSTANDARD: + min_version = max(ZSTANDARD_VERSION, min_version) self.extract_version = max(min_version, self.extract_version) self.create_version = max(min_version, self.create_version) @@ -766,6 +775,7 @@ def decompress(self, data): 14: 'lzma', 18: 'terse', 19: 'lz77', + 93: 'zstd', 97: 'wavpack', 98: 'ppmd', } @@ -785,6 +795,10 @@ def _check_compression(compression): if not lzma: raise RuntimeError( "Compression requires the (missing) lzma module") + elif compression == ZIP_ZSTANDARD: + if not zstd: + raise RuntimeError( + "Compression requires the (missing) compression.zstd module") else: raise NotImplementedError("That compression method is not supported") @@ -798,9 +812,11 @@ def _get_compressor(compress_type, compresslevel=None): if compresslevel is not None: return bz2.BZ2Compressor(compresslevel) return bz2.BZ2Compressor() - # compresslevel is ignored for ZIP_LZMA + # compresslevel is ignored for ZIP_LZMA and ZIP_ZSTANDARD elif compress_type == ZIP_LZMA: return LZMACompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdCompressor() else: return None @@ -815,6 +831,8 @@ def _get_decompressor(compress_type): return bz2.BZ2Decompressor() elif compress_type == ZIP_LZMA: return LZMADecompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdDecompressor() else: descr = compressor_names.get(compress_type) if descr: diff --git a/Makefile.pre.in b/Makefile.pre.in index 925e0a243c9e96..ebfb2d1982ba4a 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -2503,7 +2503,7 @@ maninstall: altmaninstall XMLLIBSUBDIRS= xml xml/dom xml/etree xml/parsers xml/sax LIBSUBDIRS= asyncio \ collections \ - compression compression/bz2 compression/gzip \ + compression compression/bz2 compression/gzip compression/zstd \ compression/lzma compression/zlib compression/_common \ concurrent concurrent/futures \ csv \ @@ -2671,6 +2671,7 @@ TESTSUBDIRS= idlelib/idle_test \ test/test_zipfile/_path \ test/test_zoneinfo \ 
test/test_zoneinfo/data \ + test/test_zstd \ test/tkinterdata \ test/tokenizedata \ test/tracedmodules \ @@ -3335,6 +3336,7 @@ MODULE__TESTCAPI_DEPS=$(srcdir)/Modules/_testcapi/parts.h $(srcdir)/Modules/_tes MODULE__TESTLIMITEDCAPI_DEPS=$(srcdir)/Modules/_testlimitedcapi/testcapi_long.h $(srcdir)/Modules/_testlimitedcapi/parts.h $(srcdir)/Modules/_testlimitedcapi/util.h MODULE__TESTINTERNALCAPI_DEPS=$(srcdir)/Modules/_testinternalcapi/parts.h MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h +MODULE__ZSTD_DEPS=$(srcdir)/Modules/_zstd/_zstdmodule.h $(srcdir)/Modules/_zstd/buffer.h CODECS_COMMON_HEADERS=$(srcdir)/Modules/cjkcodecs/multibytecodec.h $(srcdir)/Modules/cjkcodecs/cjkcodecs.h MODULE__CODECS_CN_DEPS=$(srcdir)/Modules/cjkcodecs/mappings_cn.h $(CODECS_COMMON_HEADERS) diff --git a/Modules/Setup b/Modules/Setup index 65c22d48ba0bb7..c9d2b2faa505e2 100644 --- a/Modules/Setup +++ b/Modules/Setup @@ -200,6 +200,7 @@ PYTHONPATH=$(COREPYTHONPATH) #_dbm _dbmmodule.c -lgdbm_compat -DUSE_GDBM_COMPAT #_gdbm _gdbmmodule.c -lgdbm #_lzma _lzmamodule.c -llzma +#_zstd _zstd/_zstdmodule.c -lzstd -I$(srcdir)/Modules/_zstd #_uuid _uuidmodule.c -luuid #zlib zlibmodule.c -lz diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in index 33e60f37d19922..8aae11c68efcbb 100644 --- a/Modules/Setup.stdlib.in +++ b/Modules/Setup.stdlib.in @@ -64,10 +64,11 @@ @MODULE__DECIMAL_TRUE@_decimal _decimal/_decimal.c # compression libs and binascii (optional CRC32 from zlib) -# bindings need -lbz2, -lz, or -llzma, respectively +# bindings need -lbz2, -llzma, -lzstd, or -lz, respectively @MODULE_BINASCII_TRUE@binascii binascii.c @MODULE__BZ2_TRUE@_bz2 _bz2module.c @MODULE__LZMA_TRUE@_lzma _lzmamodule.c +@MODULE__ZSTD_TRUE@_zstd _zstd/_zstdmodule.c _zstd/zdict.c _zstd/compressor.c _zstd/decompressor.c @MODULE_ZLIB_TRUE@zlib zlibmodule.c # dbm/gdbm diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c new file mode 100644 index 00000000000000..18dc13b3fd16f0 --- /dev/null +++ b/Modules/_zstd/_zstdmodule.c @@ -0,0 +1,914 @@ +/* +Low level interface to Meta's zstd library for use in the compression.zstd +Python module. +*/ + +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + +#include "_zstdmodule.h" + +/*[clinic input] +module _zstd + +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=4b5f5587aac15c14]*/ +#include "clinic/_zstdmodule.c.h" + + +/* Format error message and set ZstdError. 
*/ +void +set_zstd_error(const _zstd_state* const state, + error_type type, size_t zstd_ret) +{ + char *msg; + assert(ZSTD_isError(zstd_ret)); + + switch (type) + { + case ERR_DECOMPRESS: + msg = "Unable to decompress zstd data: %s"; + break; + case ERR_COMPRESS: + msg = "Unable to compress zstd data: %s"; + break; + case ERR_SET_PLEDGED_INPUT_SIZE: + msg = "Unable to set pledged uncompressed content size: %s"; + break; + + case ERR_LOAD_D_DICT: + msg = "Unable to load zstd dictionary or prefix for decompression: %s"; + break; + case ERR_LOAD_C_DICT: + msg = "Unable to load zstd dictionary or prefix for compression: %s"; + break; + + case ERR_GET_C_BOUNDS: + msg = "Unable to get zstd compression parameter bounds: %s"; + break; + case ERR_GET_D_BOUNDS: + msg = "Unable to get zstd decompression parameter bounds: %s"; + break; + case ERR_SET_C_LEVEL: + msg = "Unable to set zstd compression level: %s"; + break; + + case ERR_TRAIN_DICT: + msg = "Unable to train zstd dictionary: %s"; + break; + case ERR_FINALIZE_DICT: + msg = "Unable to finalize zstd dictionary: %s"; + break; + + default: + Py_UNREACHABLE(); + } + PyErr_Format(state->ZstdError, msg, ZSTD_getErrorName(zstd_ret)); +} + +typedef struct { + int parameter; + char parameter_name[32]; +} ParameterInfo; + +static const ParameterInfo cp_list[] = +{ + {ZSTD_c_compressionLevel, "compressionLevel"}, + {ZSTD_c_windowLog, "windowLog"}, + {ZSTD_c_hashLog, "hashLog"}, + {ZSTD_c_chainLog, "chainLog"}, + {ZSTD_c_searchLog, "searchLog"}, + {ZSTD_c_minMatch, "minMatch"}, + {ZSTD_c_targetLength, "targetLength"}, + {ZSTD_c_strategy, "strategy"}, + + {ZSTD_c_enableLongDistanceMatching, "enableLongDistanceMatching"}, + {ZSTD_c_ldmHashLog, "ldmHashLog"}, + {ZSTD_c_ldmMinMatch, "ldmMinMatch"}, + {ZSTD_c_ldmBucketSizeLog, "ldmBucketSizeLog"}, + {ZSTD_c_ldmHashRateLog, "ldmHashRateLog"}, + + {ZSTD_c_contentSizeFlag, "contentSizeFlag"}, + {ZSTD_c_checksumFlag, "checksumFlag"}, + {ZSTD_c_dictIDFlag, "dictIDFlag"}, + + {ZSTD_c_nbWorkers, "nbWorkers"}, + {ZSTD_c_jobSize, "jobSize"}, + {ZSTD_c_overlapLog, "overlapLog"} +}; + +static const ParameterInfo dp_list[] = +{ + {ZSTD_d_windowLogMax, "windowLogMax"} +}; + +void +set_parameter_error(const _zstd_state* const state, int is_compress, + int key_v, int value_v) +{ + ParameterInfo const *list; + int list_size; + char const *name; + char *type; + ZSTD_bounds bounds; + int i; + char pos_msg[128]; + + if (is_compress) { + list = cp_list; + list_size = Py_ARRAY_LENGTH(cp_list); + type = "compression"; + } + else { + list = dp_list; + list_size = Py_ARRAY_LENGTH(dp_list); + type = "decompression"; + } + + /* Find parameter's name */ + name = NULL; + for (i = 0; i < list_size; i++) { + if (key_v == (list+i)->parameter) { + name = (list+i)->parameter_name; + break; + } + } + + /* Unknown parameter */ + if (name == NULL) { + PyOS_snprintf(pos_msg, sizeof(pos_msg), + "unknown parameter (key %d)", key_v); + name = pos_msg; + } + + /* Get parameter bounds */ + if (is_compress) { + bounds = ZSTD_cParam_getBounds(key_v); + } + else { + bounds = ZSTD_dParam_getBounds(key_v); + } + if (ZSTD_isError(bounds.error)) { + PyErr_Format(state->ZstdError, + "Zstd %s parameter \"%s\" is invalid. (zstd v%s)", + type, name, ZSTD_versionString()); + return; + } + + /* Error message */ + PyErr_Format(state->ZstdError, + "Error when setting zstd %s parameter \"%s\", it " + "should %d <= value <= %d, provided value is %d. 
" + "(zstd v%s, %d-bit build)", + type, name, + bounds.lowerBound, bounds.upperBound, value_v, + ZSTD_versionString(), 8*(int)sizeof(Py_ssize_t)); +} + +static inline _zstd_state* +get_zstd_state(PyObject *module) +{ + void *state = PyModule_GetState(module); + assert(state != NULL); + return (_zstd_state *)state; +} + + +/*[clinic input] +_zstd._train_dict + + samples_bytes: PyBytesObject + Concatenation of samples. + samples_size_list: object(subclass_of='&PyList_Type') + List of samples' sizes. + dict_size: Py_ssize_t + The size of the dictionary. + / + +Internal function, train a zstd dictionary on sample data. +[clinic start generated code]*/ + +static PyObject * +_zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, + PyObject *samples_size_list, Py_ssize_t dict_size) +/*[clinic end generated code: output=ee53c34c8f77886b input=b21d092c695a3a81]*/ +{ + // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict + // are pretty similar. We should see if we can refactor them to share that code. + Py_ssize_t chunks_number; + size_t *chunk_sizes = NULL; + PyObject *dst_dict_bytes = NULL; + size_t zstd_ret; + Py_ssize_t sizes_sum; + Py_ssize_t i; + + /* Check arguments */ + if (dict_size <= 0) { + PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); + return NULL; + } + + chunks_number = Py_SIZE(samples_size_list); + if ((size_t) chunks_number > UINT32_MAX) { + PyErr_Format(PyExc_ValueError, + "The number of samples should be <= %u.", UINT32_MAX); + return NULL; + } + + /* Prepare chunk_sizes */ + chunk_sizes = PyMem_New(size_t, chunks_number); + if (chunk_sizes == NULL) { + PyErr_NoMemory(); + goto error; + } + + sizes_sum = 0; + for (i = 0; i < chunks_number; i++) { + PyObject *size = PyList_GetItemRef(samples_size_list, i); + chunk_sizes[i] = PyLong_AsSize_t(size); + Py_DECREF(size); + if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { + PyErr_Format(PyExc_ValueError, + "Items in samples_size_list should be an int " + "object, with a value between 0 and %u.", SIZE_MAX); + goto error; + } + sizes_sum += chunk_sizes[i]; + } + + if (sizes_sum != Py_SIZE(samples_bytes)) { + PyErr_SetString(PyExc_ValueError, + "The samples size list doesn't match the concatenation's size."); + goto error; + } + + /* Allocate dict buffer */ + dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); + if (dst_dict_bytes == NULL) { + goto error; + } + + /* Train the dictionary */ + char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes); + char *samples_buffer = PyBytes_AS_STRING(samples_bytes); + Py_BEGIN_ALLOW_THREADS + zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size, + samples_buffer, + chunk_sizes, (uint32_t)chunks_number); + Py_END_ALLOW_THREADS + + /* Check zstd dict error */ + if (ZDICT_isError(zstd_ret)) { + _zstd_state* const mod_state = get_zstd_state(module); + set_zstd_error(mod_state, ERR_TRAIN_DICT, zstd_ret); + goto error; + } + + /* Resize dict_buffer */ + if (_PyBytes_Resize(&dst_dict_bytes, zstd_ret) < 0) { + goto error; + } + + goto success; + +error: + Py_CLEAR(dst_dict_bytes); + +success: + PyMem_Free(chunk_sizes); + return dst_dict_bytes; +} + +/*[clinic input] +_zstd._finalize_dict + + custom_dict_bytes: PyBytesObject + Custom dictionary content. + samples_bytes: PyBytesObject + Concatenation of samples. + samples_size_list: object(subclass_of='&PyList_Type') + List of samples' sizes. + dict_size: Py_ssize_t + The size of the dictionary. 
+ compression_level: int + Optimize for a specific zstd compression level, 0 means default. + / + +Internal function, finalize a zstd dictionary. +[clinic start generated code]*/ + +static PyObject * +_zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, + PyBytesObject *samples_bytes, + PyObject *samples_size_list, Py_ssize_t dict_size, + int compression_level) +/*[clinic end generated code: output=9c2a7d8c845cee93 input=08531a803d87c56f]*/ +{ + Py_ssize_t chunks_number; + size_t *chunk_sizes = NULL; + PyObject *dst_dict_bytes = NULL; + size_t zstd_ret; + ZDICT_params_t params; + Py_ssize_t sizes_sum; + Py_ssize_t i; + + /* Check arguments */ + if (dict_size <= 0) { + PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); + return NULL; + } + + chunks_number = Py_SIZE(samples_size_list); + if ((size_t) chunks_number > UINT32_MAX) { + PyErr_Format(PyExc_ValueError, + "The number of samples should be <= %u.", UINT32_MAX); + return NULL; + } + + /* Prepare chunk_sizes */ + chunk_sizes = PyMem_New(size_t, chunks_number); + if (chunk_sizes == NULL) { + PyErr_NoMemory(); + goto error; + } + + sizes_sum = 0; + for (i = 0; i < chunks_number; i++) { + PyObject *size = PyList_GET_ITEM(samples_size_list, i); + chunk_sizes[i] = PyLong_AsSize_t(size); + if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { + PyErr_Format(PyExc_ValueError, + "Items in samples_size_list should be an int " + "object, with a value between 0 and %u.", SIZE_MAX); + goto error; + } + sizes_sum += chunk_sizes[i]; + } + + if (sizes_sum != Py_SIZE(samples_bytes)) { + PyErr_SetString(PyExc_ValueError, + "The samples size list doesn't match the concatenation's size."); + goto error; + } + + /* Allocate dict buffer */ + dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); + if (dst_dict_bytes == NULL) { + goto error; + } + + /* Parameters */ + + /* Optimize for a specific zstd compression level, 0 means default. */ + params.compressionLevel = compression_level; + /* Write log to stderr, 0 = none. */ + params.notificationLevel = 0; + /* Force dictID value, 0 means auto mode (32-bits random value). */ + params.dictID = 0; + + /* Finalize the dictionary */ + Py_BEGIN_ALLOW_THREADS + zstd_ret = ZDICT_finalizeDictionary( + PyBytes_AS_STRING(dst_dict_bytes), dict_size, + PyBytes_AS_STRING(custom_dict_bytes), Py_SIZE(custom_dict_bytes), + PyBytes_AS_STRING(samples_bytes), chunk_sizes, + (uint32_t)chunks_number, params); + Py_END_ALLOW_THREADS + + /* Check zstd dict error */ + if (ZDICT_isError(zstd_ret)) { + _zstd_state* const mod_state = get_zstd_state(module); + set_zstd_error(mod_state, ERR_FINALIZE_DICT, zstd_ret); + goto error; + } + + /* Resize dict_buffer */ + if (_PyBytes_Resize(&dst_dict_bytes, zstd_ret) < 0) { + goto error; + } + + goto success; + +error: + Py_CLEAR(dst_dict_bytes); + +success: + PyMem_Free(chunk_sizes); + return dst_dict_bytes; +} + + +/*[clinic input] +_zstd._get_param_bounds + + is_compress: bool + True for CParameter, False for DParameter. + parameter: int + The parameter to get bounds. + +Internal function, get CParameter/DParameter bounds. 
+[clinic start generated code]*/ + +static PyObject * +_zstd__get_param_bounds_impl(PyObject *module, int is_compress, + int parameter) +/*[clinic end generated code: output=b751dc710f89ef55 input=fb21ff96aff65df1]*/ +{ + ZSTD_bounds bound; + if (is_compress) { + bound = ZSTD_cParam_getBounds(parameter); + if (ZSTD_isError(bound.error)) { + _zstd_state* const mod_state = get_zstd_state(module); + set_zstd_error(mod_state, ERR_GET_C_BOUNDS, bound.error); + return NULL; + } + } + else { + bound = ZSTD_dParam_getBounds(parameter); + if (ZSTD_isError(bound.error)) { + _zstd_state* const mod_state = get_zstd_state(module); + set_zstd_error(mod_state, ERR_GET_D_BOUNDS, bound.error); + return NULL; + } + } + + return Py_BuildValue("ii", bound.lowerBound, bound.upperBound); +} + +/*[clinic input] +_zstd.get_frame_size + + frame_buffer: Py_buffer + A bytes-like object, it should start from the beginning of a frame, + and contain at least one complete frame. + +Get the size of a zstd frame, including frame header and 4-byte checksum if it has one. + +It will iterate over all blocks' headers within a frame to accumulate the frame size. +[clinic start generated code]*/ + +static PyObject * +_zstd_get_frame_size_impl(PyObject *module, Py_buffer *frame_buffer) +/*[clinic end generated code: output=a7384c2f8780f442 input=7d3ad24311893bf3]*/ +{ + size_t frame_size; + + frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, frame_buffer->len); + if (ZSTD_isError(frame_size)) { + _zstd_state* const mod_state = get_zstd_state(module); + PyErr_Format(mod_state->ZstdError, + "Error when finding the compressed size of a zstd frame. " + "Make sure the frame_buffer argument starts from the " + "beginning of a frame, and its length is not less than this " + "complete frame. Zstd error message: %s.", + ZSTD_getErrorName(frame_size)); + return NULL; + } + + return PyLong_FromSize_t(frame_size); +} + +/*[clinic input] +_zstd._get_frame_info + + frame_buffer: Py_buffer + A bytes-like object, containing the header of a zstd frame. + +Internal function, get zstd frame information from a frame header. +[clinic start generated code]*/ + +static PyObject * +_zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) +/*[clinic end generated code: output=5462855464ecdf81 input=67f1f8e4b7b89c4d]*/ +{ + uint64_t decompressed_size; + uint32_t dict_id; + + /* ZSTD_getFrameContentSize */ + decompressed_size = ZSTD_getFrameContentSize(frame_buffer->buf, + frame_buffer->len); + + /* #define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1) + #define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) */ + if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { + _zstd_state* const mod_state = get_zstd_state(module); + PyErr_SetString(mod_state->ZstdError, + "Error when getting information from the header of " + "a zstd frame. 
Make sure the frame_buffer argument " + "starts from the beginning of a frame, and its length " + "is not less than the frame header (6~18 bytes)."); + return NULL; + } + + /* ZSTD_getDictID_fromFrame */ + dict_id = ZSTD_getDictID_fromFrame(frame_buffer->buf, frame_buffer->len); + + /* Build tuple */ + if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) { + return Py_BuildValue("OI", Py_None, dict_id); + } + return Py_BuildValue("KI", decompressed_size, dict_id); +} + +/*[clinic input] +_zstd._set_parameter_types + + c_parameter_type: object(subclass_of='&PyType_Type') + CParameter IntEnum type object + d_parameter_type: object(subclass_of='&PyType_Type') + DParameter IntEnum type object + +Internal function, set CParameter/DParameter types for validity check. +[clinic start generated code]*/ + +static PyObject * +_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, + PyObject *d_parameter_type) +/*[clinic end generated code: output=a13d4890ccbd2873 input=3e7d0d37c3a1045a]*/ +{ + _zstd_state* const mod_state = get_zstd_state(module); + + if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { + PyErr_SetString(PyExc_ValueError, + "The two arguments should be CParameter and " + "DParameter types."); + return NULL; + } + + Py_XDECREF(mod_state->CParameter_type); + Py_INCREF(c_parameter_type); + mod_state->CParameter_type = (PyTypeObject*) c_parameter_type; + + Py_XDECREF(mod_state->DParameter_type); + Py_INCREF(d_parameter_type); + mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; + + Py_RETURN_NONE; +} + +static PyMethodDef _zstd_methods[] = { + _ZSTD__TRAIN_DICT_METHODDEF + _ZSTD__FINALIZE_DICT_METHODDEF + _ZSTD__GET_PARAM_BOUNDS_METHODDEF + _ZSTD_GET_FRAME_SIZE_METHODDEF + _ZSTD__GET_FRAME_INFO_METHODDEF + _ZSTD__SET_PARAMETER_TYPES_METHODDEF + + {0} +}; + + +#define ADD_INT_PREFIX_MACRO(module, macro) \ + do { \ + if (PyModule_AddIntConstant(module, "_" #macro, macro) < 0) { \ + return -1; \ + } \ + } while(0) + +static int +add_parameters(PyObject *module) +{ + /* If you add new parameters, please also add them to cp_list/dp_list above. 
*/ + + /* Compression parameters */ + ADD_INT_PREFIX_MACRO(module, ZSTD_c_compressionLevel); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_windowLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_hashLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_chainLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_searchLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_minMatch); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_targetLength); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_strategy); + + ADD_INT_PREFIX_MACRO(module, ZSTD_c_enableLongDistanceMatching); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_ldmHashLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_ldmMinMatch); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_ldmBucketSizeLog); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_ldmHashRateLog); + + ADD_INT_PREFIX_MACRO(module, ZSTD_c_contentSizeFlag); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_checksumFlag); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_dictIDFlag); + + ADD_INT_PREFIX_MACRO(module, ZSTD_c_nbWorkers); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_jobSize); + ADD_INT_PREFIX_MACRO(module, ZSTD_c_overlapLog); + + /* Decompression parameters */ + ADD_INT_PREFIX_MACRO(module, ZSTD_d_windowLogMax); + + /* ZSTD_strategy enum */ + ADD_INT_PREFIX_MACRO(module, ZSTD_fast); + ADD_INT_PREFIX_MACRO(module, ZSTD_dfast); + ADD_INT_PREFIX_MACRO(module, ZSTD_greedy); + ADD_INT_PREFIX_MACRO(module, ZSTD_lazy); + ADD_INT_PREFIX_MACRO(module, ZSTD_lazy2); + ADD_INT_PREFIX_MACRO(module, ZSTD_btlazy2); + ADD_INT_PREFIX_MACRO(module, ZSTD_btopt); + ADD_INT_PREFIX_MACRO(module, ZSTD_btultra); + ADD_INT_PREFIX_MACRO(module, ZSTD_btultra2); + + return 0; +} + +static inline PyObject * +get_zstd_version_info(void) +{ + uint32_t ver = ZSTD_versionNumber(); + uint32_t major, minor, release; + + major = ver / 10000; + minor = (ver / 100) % 100; + release = ver % 100; + + return Py_BuildValue("III", major, minor, release); +} + +static inline int +add_vars_to_module(PyObject *module) +{ + PyObject *obj; + + /* zstd_version, a str. */ + if (PyModule_AddStringConstant(module, "zstd_version", + ZSTD_versionString()) < 0) { + return -1; + } + + /* zstd_version_info, a tuple. 
*/ + obj = get_zstd_version_info(); + if (PyModule_AddObjectRef(module, "zstd_version_info", obj) < 0) { + Py_XDECREF(obj); + return -1; + } + Py_DECREF(obj); + + /* Add zstd parameters */ + if (add_parameters(module) < 0) { + return -1; + } + + /* _compressionLevel_values: (default, min, max) + ZSTD_defaultCLevel() was added in zstd v1.5.0 */ + obj = Py_BuildValue("iii", +#if ZSTD_VERSION_NUMBER < 10500 + ZSTD_CLEVEL_DEFAULT, +#else + ZSTD_defaultCLevel(), +#endif + ZSTD_minCLevel(), + ZSTD_maxCLevel()); + if (PyModule_AddObjectRef(module, + "_compressionLevel_values", + obj) < 0) { + Py_XDECREF(obj); + return -1; + } + Py_DECREF(obj); + + /* _ZSTD_CStreamSizes */ + obj = Py_BuildValue("II", + (uint32_t)ZSTD_CStreamInSize(), + (uint32_t)ZSTD_CStreamOutSize()); + if (PyModule_AddObjectRef(module, "_ZSTD_CStreamSizes", obj) < 0) { + Py_XDECREF(obj); + return -1; + } + Py_DECREF(obj); + + /* _ZSTD_DStreamSizes */ + obj = Py_BuildValue("II", + (uint32_t)ZSTD_DStreamInSize(), + (uint32_t)ZSTD_DStreamOutSize()); + if (PyModule_AddObjectRef(module, "_ZSTD_DStreamSizes", obj) < 0) { + Py_XDECREF(obj); + return -1; + } + Py_DECREF(obj); + + /* _ZSTD_CONFIG */ + obj = Py_BuildValue("isOOO", 8*(int)sizeof(Py_ssize_t), "c", + Py_False, + Py_True, +/* User mremap output buffer */ +#if defined(HAVE_MREMAP) + Py_True +#else + Py_False +#endif + ); + if (PyModule_AddObjectRef(module, "_ZSTD_CONFIG", obj) < 0) { + Py_XDECREF(obj); + return -1; + } + Py_DECREF(obj); + + return 0; +} + +#define ADD_STR_TO_STATE_MACRO(STR) \ + do { \ + mod_state->str_##STR = PyUnicode_FromString(#STR); \ + if (mod_state->str_##STR == NULL) { \ + return -1; \ + } \ + } while(0) + +static inline int +add_type_to_module(PyObject *module, const char *name, + PyType_Spec *type_spec, PyTypeObject **dest) +{ + PyObject *temp = PyType_FromModuleAndSpec(module, type_spec, NULL); + + if (PyModule_AddObjectRef(module, name, temp) < 0) { + Py_XDECREF(temp); + return -1; + } + + *dest = (PyTypeObject*) temp; + + return 0; +} + +static inline int +add_constant_to_type(PyTypeObject *type, const char *name, long value) +{ + PyObject *temp; + + temp = PyLong_FromLong(value); + if (temp == NULL) { + return -1; + } + + int rc = PyObject_SetAttrString((PyObject*) type, name, temp); + Py_DECREF(temp); + return rc; +} + +static int _zstd_exec(PyObject *module) { + _zstd_state* const mod_state = get_zstd_state(module); + + /* Reusable objects & variables */ + mod_state->empty_bytes = PyBytes_FromStringAndSize(NULL, 0); + if (mod_state->empty_bytes == NULL) { + return -1; + } + + mod_state->empty_readonly_memoryview = + PyMemoryView_FromMemory((char*)mod_state, 0, PyBUF_READ); + if (mod_state->empty_readonly_memoryview == NULL) { + return -1; + } + + /* Add str to module state */ + ADD_STR_TO_STATE_MACRO(read); + ADD_STR_TO_STATE_MACRO(readinto); + ADD_STR_TO_STATE_MACRO(write); + ADD_STR_TO_STATE_MACRO(flush); + + mod_state->CParameter_type = NULL; + mod_state->DParameter_type = NULL; + + /* Add variables to module */ + if (add_vars_to_module(module) < 0) { + return -1; + } + + /* ZstdError */ + mod_state->ZstdError = PyErr_NewExceptionWithDoc( + "_zstd.ZstdError", + "Call to the underlying zstd library failed.", + NULL, NULL); + if (mod_state->ZstdError == NULL) { + return -1; + } + + if (PyModule_AddObjectRef(module, "ZstdError", mod_state->ZstdError) < 0) { + Py_DECREF(mod_state->ZstdError); + return -1; + } + + /* ZstdDict */ + if (add_type_to_module(module, + "ZstdDict", + &zstddict_type_spec, + &mod_state->ZstdDict_type) < 0) { + return -1; 
+ } + + // ZstdCompressor + if (add_type_to_module(module, + "ZstdCompressor", + &zstdcompressor_type_spec, + &mod_state->ZstdCompressor_type) < 0) { + return -1; + } + + // Add EndDirective enum to ZstdCompressor + if (add_constant_to_type(mod_state->ZstdCompressor_type, + "CONTINUE", + ZSTD_e_continue) < 0) { + return -1; + } + + if (add_constant_to_type(mod_state->ZstdCompressor_type, + "FLUSH_BLOCK", + ZSTD_e_flush) < 0) { + return -1; + } + + if (add_constant_to_type(mod_state->ZstdCompressor_type, + "FLUSH_FRAME", + ZSTD_e_end) < 0) { + return -1; + } + + // ZstdDecompressor + if (add_type_to_module(module, + "ZstdDecompressor", + &ZstdDecompressor_type_spec, + &mod_state->ZstdDecompressor_type) < 0) { + return -1; + } + + return 0; +} + +static int +_zstd_traverse(PyObject *module, visitproc visit, void *arg) +{ + _zstd_state* const mod_state = get_zstd_state(module); + + Py_VISIT(mod_state->empty_bytes); + Py_VISIT(mod_state->empty_readonly_memoryview); + Py_VISIT(mod_state->str_read); + Py_VISIT(mod_state->str_readinto); + Py_VISIT(mod_state->str_write); + Py_VISIT(mod_state->str_flush); + + Py_VISIT(mod_state->ZstdDict_type); + Py_VISIT(mod_state->ZstdCompressor_type); + + Py_VISIT(mod_state->ZstdDecompressor_type); + + Py_VISIT(mod_state->ZstdError); + + Py_VISIT(mod_state->CParameter_type); + Py_VISIT(mod_state->DParameter_type); + return 0; +} + +static int +_zstd_clear(PyObject *module) +{ + _zstd_state* const mod_state = get_zstd_state(module); + + Py_CLEAR(mod_state->empty_bytes); + Py_CLEAR(mod_state->empty_readonly_memoryview); + Py_CLEAR(mod_state->str_read); + Py_CLEAR(mod_state->str_readinto); + Py_CLEAR(mod_state->str_write); + Py_CLEAR(mod_state->str_flush); + + Py_CLEAR(mod_state->ZstdDict_type); + Py_CLEAR(mod_state->ZstdCompressor_type); + + Py_CLEAR(mod_state->ZstdDecompressor_type); + + Py_CLEAR(mod_state->ZstdError); + + Py_CLEAR(mod_state->CParameter_type); + Py_CLEAR(mod_state->DParameter_type); + return 0; +} + +static void +_zstd_free(void *module) +{ + (void)_zstd_clear((PyObject *)module); +} + +static struct PyModuleDef_Slot _zstd_slots[] = { + {Py_mod_exec, _zstd_exec}, + {Py_mod_gil, Py_MOD_GIL_NOT_USED}, + + {0} +}; + +struct PyModuleDef _zstdmodule = { + PyModuleDef_HEAD_INIT, + .m_name = "_zstd", + .m_size = sizeof(_zstd_state), + .m_slots = _zstd_slots, + .m_methods = _zstd_methods, + .m_traverse = _zstd_traverse, + .m_clear = _zstd_clear, + .m_free = _zstd_free +}; + +PyMODINIT_FUNC +PyInit__zstd(void) +{ + return PyModuleDef_Init(&_zstdmodule); +} diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h new file mode 100644 index 00000000000000..faa9b8fec65870 --- /dev/null +++ b/Modules/_zstd/_zstdmodule.h @@ -0,0 +1,202 @@ +#pragma once +/* +Low level interface to Meta's zstd library for use in the compression.zstd +Python module. 
+*/ + +/* Declarations shared between different parts of the _zstd module */ + +#include "Python.h" + +#include "zstd.h" +#include "zdict.h" + + +#define PYTHON_MINIMUM_SUPPORTED_ZSTD_VERSION 10400 + +#if ZSTD_VERSION_NUMBER < PYTHON_MINIMUM_SUPPORTED_ZSTD_VERSION + #error "_zstd module requires zstd v1.4.0+" +#endif + +/* Forward declaration of module state */ +typedef struct _zstd_state _zstd_state; + +/* Forward reference of module def */ +extern PyModuleDef _zstdmodule; + +/* For clinic type calculations */ +static inline _zstd_state * +get_zstd_state_from_type(PyTypeObject *type) { + PyObject *module = PyType_GetModuleByDef(type, &_zstdmodule); + if (module == NULL) { + return NULL; + } + void *state = PyModule_GetState(module); + assert(state != NULL); + return (_zstd_state *)state; +} + +extern PyType_Spec zstddict_type_spec; +extern PyType_Spec zstdcompressor_type_spec; +extern PyType_Spec ZstdDecompressor_type_spec; + +struct _zstd_state { + PyObject *empty_bytes; + PyObject *empty_readonly_memoryview; + PyObject *str_read; + PyObject *str_readinto; + PyObject *str_write; + PyObject *str_flush; + + PyTypeObject *ZstdDict_type; + PyTypeObject *ZstdCompressor_type; + PyTypeObject *ZstdDecompressor_type; + PyObject *ZstdError; + + PyTypeObject *CParameter_type; + PyTypeObject *DParameter_type; +}; + +typedef struct { + PyObject_HEAD + + /* Reusable compress/decompress dictionaries; they are created once and + can be shared by multiple threads concurrently, since their usage is + read-only. + c_dicts is a dict, int(compressionLevel):PyCapsule(ZSTD_CDict*) */ + ZSTD_DDict *d_dict; + PyObject *c_dicts; + + /* Content of the dictionary, bytes object. */ + PyObject *dict_content; + /* Dictionary id */ + uint32_t dict_id; + + /* __init__ has been called, 0 or 1. */ + int inited; +} ZstdDict; + +typedef struct { + PyObject_HEAD + + /* Compression context */ + ZSTD_CCtx *cctx; + + /* ZstdDict object in use */ + PyObject *dict; + + /* Last mode, initialized to ZSTD_e_end */ + int last_mode; + + /* (nbWorker >= 1) ? 1 : 0 */ + int use_multithread; + + /* Compression level */ + int compression_level; + + /* __init__ has been called, 0 or 1. */ + int inited; +} ZstdCompressor; + +typedef struct { + PyObject_HEAD + + /* Decompression context */ + ZSTD_DCtx *dctx; + + /* ZstdDict object in use */ + PyObject *dict; + + /* Unconsumed input data */ + char *input_buffer; + size_t input_buffer_size; + size_t in_begin, in_end; + + /* Unused data */ + PyObject *unused_data; + + /* 0 if decompressor has (or may have) unconsumed input data, 0 or 1. */ + char needs_input; + + /* For decompress(), 0 or 1. + 1 when both input and output streams are at a frame edge, meaning a + frame is completely decoded and fully flushed, or the decompressor + has just been initialized. */ + char at_frame_edge; + + /* For ZstdDecompressor, 0 or 1. + 1 means the end of the first frame has been reached. */ + char eof; + + /* Used for fast reset of the above three variables */ + char _unused_char_for_align; + + /* __init__ has been called, 0 or 1. 
*/ + int inited; +} ZstdDecompressor; + +typedef enum { + TYPE_DECOMPRESSOR, // , ZstdDecompressor class + TYPE_ENDLESS_DECOMPRESSOR, // , decompress() function +} decompress_type; + +typedef enum { + ERR_DECOMPRESS, + ERR_COMPRESS, + ERR_SET_PLEDGED_INPUT_SIZE, + + ERR_LOAD_D_DICT, + ERR_LOAD_C_DICT, + + ERR_GET_C_BOUNDS, + ERR_GET_D_BOUNDS, + ERR_SET_C_LEVEL, + + ERR_TRAIN_DICT, + ERR_FINALIZE_DICT +} error_type; + +typedef enum { + DICT_TYPE_DIGESTED = 0, + DICT_TYPE_UNDIGESTED = 1, + DICT_TYPE_PREFIX = 2 +} dictionary_type; + +static inline int +mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out) { + return in->size == in->pos && out->size != out->pos; +} + +/* Format error message and set ZstdError. */ +extern void +set_zstd_error(const _zstd_state* const state, + const error_type type, size_t zstd_ret); + +extern void +set_parameter_error(const _zstd_state* const state, int is_compress, + int key_v, int value_v); + +static const char init_twice_msg[] = "__init__ method is called twice."; + +extern int +_PyZstd_load_c_dict(ZstdCompressor *self, PyObject *dict); + +extern int +_PyZstd_load_d_dict(ZstdDecompressor *self, PyObject *dict); + +extern int +_PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, + const char *arg_name, const char *arg_type); + +extern int +_PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options); + +extern PyObject * +decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, + Py_ssize_t max_length, + Py_ssize_t initial_size, + decompress_type type); + +extern PyObject * +compress_impl(ZstdCompressor *self, Py_buffer *data, + ZSTD_EndDirective end_directive); diff --git a/Modules/_zstd/buffer.h b/Modules/_zstd/buffer.h new file mode 100644 index 00000000000000..319b1214833fcf --- /dev/null +++ b/Modules/_zstd/buffer.h @@ -0,0 +1,104 @@ +/* +Low level interface to Meta's zstd library for use in the compression.zstd +Python module. +*/ + +#include "_zstdmodule.h" +#include "pycore_blocks_output_buffer.h" + +/* Blocks output buffer wrapper code */ + +/* Initialize the buffer, and grow the buffer. + Return 0 on success + Return -1 on failure */ +static inline int +_OutputBuffer_InitAndGrow(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob, + Py_ssize_t max_length) +{ + /* Ensure .list was set to NULL */ + assert(buffer->list == NULL); + + Py_ssize_t res = _BlocksOutputBuffer_InitAndGrow(buffer, max_length, &ob->dst); + if (res < 0) { + return -1; + } + ob->size = (size_t) res; + ob->pos = 0; + return 0; +} + +/* Initialize the buffer, with an initial size. + init_size: the initial size. + Return 0 on success + Return -1 on failure */ +static inline int +_OutputBuffer_InitWithSize(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob, + Py_ssize_t max_length, + Py_ssize_t init_size) +{ + Py_ssize_t block_size; + + /* Ensure .list was set to NULL */ + assert(buffer->list == NULL); + + /* Get block size */ + if (0 <= max_length && max_length < init_size) { + block_size = max_length; + } + else { + block_size = init_size; + } + + Py_ssize_t res = _BlocksOutputBuffer_InitWithSize(buffer, block_size, &ob->dst); + if (res < 0) { + return -1; + } + // Set max_length, InitWithSize doesn't do this + buffer->max_length = max_length; + ob->size = (size_t) res; + ob->pos = 0; + return 0; +} + +/* Grow the buffer. 
+ Return 0 on success + Return -1 on failure */ +static inline int +_OutputBuffer_Grow(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob) +{ + assert(ob->pos == ob->size); + Py_ssize_t res = _BlocksOutputBuffer_Grow(buffer, &ob->dst, 0); + if (res < 0) { + return -1; + } + ob->size = (size_t) res; + ob->pos = 0; + return 0; +} + +/* Finish the buffer. + Return a bytes object on success + Return NULL on failure */ +static inline PyObject * +_OutputBuffer_Finish(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob) +{ + return _BlocksOutputBuffer_Finish(buffer, ob->size - ob->pos); +} + +/* Clean up the buffer */ +static inline void +_OutputBuffer_OnError(_BlocksOutputBuffer *buffer) +{ + _BlocksOutputBuffer_OnError(buffer); +} + +/* Whether the output data has reached max_length. +The avail_out must be 0, please check it before calling. */ +static inline int +_OutputBuffer_ReachedMaxLength(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob) +{ + /* Ensure (data size == allocated size) */ + assert(ob->pos == ob->size); + + return buffer->allocated == buffer->max_length; +} diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h new file mode 100644 index 00000000000000..4b78bded67bca7 --- /dev/null +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -0,0 +1,432 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) +# include "pycore_gc.h" // PyGC_Head +# include "pycore_runtime.h" // _Py_ID() +#endif +#include "pycore_abstract.h" // _PyNumber_Index() +#include "pycore_modsupport.h" // _PyArg_CheckPositional() + +PyDoc_STRVAR(_zstd__train_dict__doc__, +"_train_dict($module, samples_bytes, samples_size_list, dict_size, /)\n" +"--\n" +"\n" +"Internal function, train a zstd dictionary on sample data.\n" +"\n" +" samples_bytes\n" +" Concatenation of samples.\n" +" samples_size_list\n" +" List of samples\' sizes.\n" +" dict_size\n" +" The size of the dictionary."); + +#define _ZSTD__TRAIN_DICT_METHODDEF \ + {"_train_dict", _PyCFunction_CAST(_zstd__train_dict), METH_FASTCALL, _zstd__train_dict__doc__}, + +static PyObject * +_zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, + PyObject *samples_size_list, Py_ssize_t dict_size); + +static PyObject * +_zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyBytesObject *samples_bytes; + PyObject *samples_size_list; + Py_ssize_t dict_size; + + if (!_PyArg_CheckPositional("_train_dict", nargs, 3, 3)) { + goto exit; + } + if (!PyBytes_Check(args[0])) { + _PyArg_BadArgument("_train_dict", "argument 1", "bytes", args[0]); + goto exit; + } + samples_bytes = (PyBytesObject *)args[0]; + if (!PyList_Check(args[1])) { + _PyArg_BadArgument("_train_dict", "argument 2", "list", args[1]); + goto exit; + } + samples_size_list = args[1]; + { + Py_ssize_t ival = -1; + PyObject *iobj = _PyNumber_Index(args[2]); + if (iobj != NULL) { + ival = PyLong_AsSsize_t(iobj); + Py_DECREF(iobj); + } + if (ival == -1 && PyErr_Occurred()) { + goto exit; + } + dict_size = ival; + } + return_value = _zstd__train_dict_impl(module, samples_bytes, samples_size_list, dict_size); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd__finalize_dict__doc__, +"_finalize_dict($module, custom_dict_bytes, samples_bytes,\n" +" samples_size_list, dict_size, compression_level, /)\n" +"--\n" +"\n" +"Internal function, finalize a zstd dictionary.\n" +"\n" +" custom_dict_bytes\n" +" Custom dictionary content.\n" +" 
samples_bytes\n" +" Concatenation of samples.\n" +" samples_size_list\n" +" List of samples\' sizes.\n" +" dict_size\n" +" The size of the dictionary.\n" +" compression_level\n" +" Optimize for a specific zstd compression level, 0 means default."); + +#define _ZSTD__FINALIZE_DICT_METHODDEF \ + {"_finalize_dict", _PyCFunction_CAST(_zstd__finalize_dict), METH_FASTCALL, _zstd__finalize_dict__doc__}, + +static PyObject * +_zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, + PyBytesObject *samples_bytes, + PyObject *samples_size_list, Py_ssize_t dict_size, + int compression_level); + +static PyObject * +_zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyBytesObject *custom_dict_bytes; + PyBytesObject *samples_bytes; + PyObject *samples_size_list; + Py_ssize_t dict_size; + int compression_level; + + if (!_PyArg_CheckPositional("_finalize_dict", nargs, 5, 5)) { + goto exit; + } + if (!PyBytes_Check(args[0])) { + _PyArg_BadArgument("_finalize_dict", "argument 1", "bytes", args[0]); + goto exit; + } + custom_dict_bytes = (PyBytesObject *)args[0]; + if (!PyBytes_Check(args[1])) { + _PyArg_BadArgument("_finalize_dict", "argument 2", "bytes", args[1]); + goto exit; + } + samples_bytes = (PyBytesObject *)args[1]; + if (!PyList_Check(args[2])) { + _PyArg_BadArgument("_finalize_dict", "argument 3", "list", args[2]); + goto exit; + } + samples_size_list = args[2]; + { + Py_ssize_t ival = -1; + PyObject *iobj = _PyNumber_Index(args[3]); + if (iobj != NULL) { + ival = PyLong_AsSsize_t(iobj); + Py_DECREF(iobj); + } + if (ival == -1 && PyErr_Occurred()) { + goto exit; + } + dict_size = ival; + } + compression_level = PyLong_AsInt(args[4]); + if (compression_level == -1 && PyErr_Occurred()) { + goto exit; + } + return_value = _zstd__finalize_dict_impl(module, custom_dict_bytes, samples_bytes, samples_size_list, dict_size, compression_level); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd__get_param_bounds__doc__, +"_get_param_bounds($module, /, is_compress, parameter)\n" +"--\n" +"\n" +"Internal function, get CParameter/DParameter bounds.\n" +"\n" +" is_compress\n" +" True for CParameter, False for DParameter.\n" +" parameter\n" +" The parameter to get bounds."); + +#define _ZSTD__GET_PARAM_BOUNDS_METHODDEF \ + {"_get_param_bounds", _PyCFunction_CAST(_zstd__get_param_bounds), METH_FASTCALL|METH_KEYWORDS, _zstd__get_param_bounds__doc__}, + +static PyObject * +_zstd__get_param_bounds_impl(PyObject *module, int is_compress, + int parameter); + +static PyObject * +_zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(is_compress), &_Py_ID(parameter), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"is_compress", "parameter", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "_get_param_bounds", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + int is_compress; + int parameter; + + args = 
_PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + is_compress = PyObject_IsTrue(args[0]); + if (is_compress < 0) { + goto exit; + } + parameter = PyLong_AsInt(args[1]); + if (parameter == -1 && PyErr_Occurred()) { + goto exit; + } + return_value = _zstd__get_param_bounds_impl(module, is_compress, parameter); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd_get_frame_size__doc__, +"get_frame_size($module, /, frame_buffer)\n" +"--\n" +"\n" +"Get the size of a zstd frame, including frame header and 4-byte checksum if it has one.\n" +"\n" +" frame_buffer\n" +" A bytes-like object, it should start from the beginning of a frame,\n" +" and contain at least one complete frame.\n" +"\n" +"It will iterate over all blocks\' headers within a frame to accumulate the frame size."); + +#define _ZSTD_GET_FRAME_SIZE_METHODDEF \ + {"get_frame_size", _PyCFunction_CAST(_zstd_get_frame_size), METH_FASTCALL|METH_KEYWORDS, _zstd_get_frame_size__doc__}, + +static PyObject * +_zstd_get_frame_size_impl(PyObject *module, Py_buffer *frame_buffer); + +static PyObject * +_zstd_get_frame_size(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 1 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(frame_buffer), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"frame_buffer", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "get_frame_size", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[1]; + Py_buffer frame_buffer = {NULL, NULL}; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 1, /*maxpos*/ 1, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (PyObject_GetBuffer(args[0], &frame_buffer, PyBUF_SIMPLE) != 0) { + goto exit; + } + return_value = _zstd_get_frame_size_impl(module, &frame_buffer); + +exit: + /* Cleanup for frame_buffer */ + if (frame_buffer.obj) { + PyBuffer_Release(&frame_buffer); + } + + return return_value; +} + +PyDoc_STRVAR(_zstd__get_frame_info__doc__, +"_get_frame_info($module, /, frame_buffer)\n" +"--\n" +"\n" +"Internal function, get zstd frame information from a frame header.\n" +"\n" +" frame_buffer\n" +" A bytes-like object, containing the header of a zstd frame."); + +#define _ZSTD__GET_FRAME_INFO_METHODDEF \ + {"_get_frame_info", _PyCFunction_CAST(_zstd__get_frame_info), METH_FASTCALL|METH_KEYWORDS, _zstd__get_frame_info__doc__}, + +static PyObject * +_zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer); + +static PyObject * +_zstd__get_frame_info(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 1 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item 
= { &_Py_ID(frame_buffer), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"frame_buffer", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "_get_frame_info", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[1]; + Py_buffer frame_buffer = {NULL, NULL}; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 1, /*maxpos*/ 1, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (PyObject_GetBuffer(args[0], &frame_buffer, PyBUF_SIMPLE) != 0) { + goto exit; + } + return_value = _zstd__get_frame_info_impl(module, &frame_buffer); + +exit: + /* Cleanup for frame_buffer */ + if (frame_buffer.obj) { + PyBuffer_Release(&frame_buffer); + } + + return return_value; +} + +PyDoc_STRVAR(_zstd__set_parameter_types__doc__, +"_set_parameter_types($module, /, c_parameter_type, d_parameter_type)\n" +"--\n" +"\n" +"Internal function, set CParameter/DParameter types for validity check.\n" +"\n" +" c_parameter_type\n" +" CParameter IntEnum type object\n" +" d_parameter_type\n" +" DParameter IntEnum type object"); + +#define _ZSTD__SET_PARAMETER_TYPES_METHODDEF \ + {"_set_parameter_types", _PyCFunction_CAST(_zstd__set_parameter_types), METH_FASTCALL|METH_KEYWORDS, _zstd__set_parameter_types__doc__}, + +static PyObject * +_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, + PyObject *d_parameter_type); + +static PyObject * +_zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(c_parameter_type), &_Py_ID(d_parameter_type), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"c_parameter_type", "d_parameter_type", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "_set_parameter_types", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + PyObject *c_parameter_type; + PyObject *d_parameter_type; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (!PyObject_TypeCheck(args[0], &PyType_Type)) { + _PyArg_BadArgument("_set_parameter_types", "argument 'c_parameter_type'", (&PyType_Type)->tp_name, args[0]); + goto exit; + } + c_parameter_type = args[0]; + if (!PyObject_TypeCheck(args[1], &PyType_Type)) { + _PyArg_BadArgument("_set_parameter_types", "argument 'd_parameter_type'", (&PyType_Type)->tp_name, args[1]); + goto exit; + } + d_parameter_type = args[1]; + return_value = _zstd__set_parameter_types_impl(module, c_parameter_type, d_parameter_type); + +exit: + return return_value; +} +/*[clinic end generated code: output=077c8ea2b11fb188 input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/clinic/compressor.c.h b/Modules/_zstd/clinic/compressor.c.h new file mode 100644 index 00000000000000..d7909cdf89fcd1 --- /dev/null +++ 
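_set_parameter_types() is the hook the pure-Python side is expected to call once at import time so the C module can recognise option-dict keys. A hedged sketch with stand-in IntEnum classes; only the registration call itself comes from this diff, the member names and values below are hypothetical:

    import enum
    import _zstd

    class CParameter(enum.IntEnum):       # stand-in for the real CParameter enum
        compression_level = 100           # hypothetical member/value

    class DParameter(enum.IntEnum):       # stand-in for the real DParameter enum
        window_log_max = 100              # hypothetical member/value

    # Register the two enum types so option-dict keys can be validated in C.
    _zstd._set_parameter_types(CParameter, DParameter)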
b/Modules/_zstd/clinic/compressor.c.h @@ -0,0 +1,255 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) +# include "pycore_gc.h" // PyGC_Head +# include "pycore_runtime.h" // _Py_ID() +#endif +#include "pycore_modsupport.h" // _PyArg_UnpackKeywords() + +PyDoc_STRVAR(_zstd_ZstdCompressor___init____doc__, +"ZstdCompressor(level=None, options=None, zstd_dict=None)\n" +"--\n" +"\n" +"Create a compressor object for compressing data incrementally.\n" +"\n" +" level\n" +" The compression level to use, defaults to ZSTD_CLEVEL_DEFAULT.\n" +" options\n" +" A dict object that contains advanced compression parameters.\n" +" zstd_dict\n" +" A ZstdDict object, a pre-trained zstd dictionary.\n" +"\n" +"Thread-safe at method level. For one-shot compression, use the compress()\n" +"function instead."); + +static int +_zstd_ZstdCompressor___init___impl(ZstdCompressor *self, PyObject *level, + PyObject *options, PyObject *zstd_dict); + +static int +_zstd_ZstdCompressor___init__(PyObject *self, PyObject *args, PyObject *kwargs) +{ + int return_value = -1; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 3 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(level), &_Py_ID(options), &_Py_ID(zstd_dict), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"level", "options", "zstd_dict", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "ZstdCompressor", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[3]; + PyObject * const *fastargs; + Py_ssize_t nargs = PyTuple_GET_SIZE(args); + Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 0; + PyObject *level = Py_None; + PyObject *options = Py_None; + PyObject *zstd_dict = Py_None; + + fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, + /*minpos*/ 0, /*maxpos*/ 3, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!fastargs) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + if (fastargs[0]) { + level = fastargs[0]; + if (!--noptargs) { + goto skip_optional_pos; + } + } + if (fastargs[1]) { + options = fastargs[1]; + if (!--noptargs) { + goto skip_optional_pos; + } + } + zstd_dict = fastargs[2]; +skip_optional_pos: + return_value = _zstd_ZstdCompressor___init___impl((ZstdCompressor *)self, level, options, zstd_dict); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdCompressor_compress__doc__, +"compress($self, /, data, mode=ZstdCompressor.CONTINUE)\n" +"--\n" +"\n" +"Provide data to the compressor object.\n" +"\n" +" mode\n" +" Can be these 3 values ZstdCompressor.CONTINUE,\n" +" ZstdCompressor.FLUSH_BLOCK, ZstdCompressor.FLUSH_FRAME\n" +"\n" +"Return a chunk of compressed data if possible, or b\'\' otherwise. 
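A short construction sketch for the ZstdCompressor signature documented above; the import path assumes compression.zstd re-exports the class:

    from compression.zstd import ZstdCompressor

    comp = ZstdCompressor(level=10)       # pick a level ...
    # ... or pass an options dict keyed by CParameter attributes instead,
    # and/or a ZstdDict via zstd_dict=...; level and options are alternatives.
    data = comp.compress(b"incremental input") + comp.flush()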
When you have\n" +"finished providing data to the compressor, call the flush() method to finish\n" +"the compression process."); + +#define _ZSTD_ZSTDCOMPRESSOR_COMPRESS_METHODDEF \ + {"compress", _PyCFunction_CAST(_zstd_ZstdCompressor_compress), METH_FASTCALL|METH_KEYWORDS, _zstd_ZstdCompressor_compress__doc__}, + +static PyObject * +_zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, + int mode); + +static PyObject * +_zstd_ZstdCompressor_compress(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(data), &_Py_ID(mode), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"data", "mode", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "compress", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 1; + Py_buffer data = {NULL, NULL}; + int mode = ZSTD_e_continue; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 1, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (PyObject_GetBuffer(args[0], &data, PyBUF_SIMPLE) != 0) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + mode = PyLong_AsInt(args[1]); + if (mode == -1 && PyErr_Occurred()) { + goto exit; + } +skip_optional_pos: + return_value = _zstd_ZstdCompressor_compress_impl((ZstdCompressor *)self, &data, mode); + +exit: + /* Cleanup for data */ + if (data.obj) { + PyBuffer_Release(&data); + } + + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdCompressor_flush__doc__, +"flush($self, /, mode=ZstdCompressor.FLUSH_FRAME)\n" +"--\n" +"\n" +"Finish the compression process.\n" +"\n" +" mode\n" +" Can be these 2 values ZstdCompressor.FLUSH_FRAME,\n" +" ZstdCompressor.FLUSH_BLOCK\n" +"\n" +"Flush any remaining data left in internal buffers. 
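An incremental-compression sketch using the three modes named in the compress() docstring (import path assumed as above):

    from compression.zstd import ZstdCompressor

    comp = ZstdCompressor()
    out = comp.compress(b"first part")                                    # CONTINUE (default)
    out += comp.compress(b"second part", mode=ZstdCompressor.FLUSH_BLOCK)
    out += comp.flush()                                                   # FLUSH_FRAME ends the frame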
Since zstd data consists\n" +"of one or more independent frames, the compressor object can still be used\n" +"after this method is called."); + +#define _ZSTD_ZSTDCOMPRESSOR_FLUSH_METHODDEF \ + {"flush", _PyCFunction_CAST(_zstd_ZstdCompressor_flush), METH_FASTCALL|METH_KEYWORDS, _zstd_ZstdCompressor_flush__doc__}, + +static PyObject * +_zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode); + +static PyObject * +_zstd_ZstdCompressor_flush(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 1 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(mode), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"mode", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "flush", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[1]; + Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 0; + int mode = ZSTD_e_end; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 0, /*maxpos*/ 1, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + mode = PyLong_AsInt(args[0]); + if (mode == -1 && PyErr_Occurred()) { + goto exit; + } +skip_optional_pos: + return_value = _zstd_ZstdCompressor_flush_impl((ZstdCompressor *)self, mode); + +exit: + return return_value; +} +/*[clinic end generated code: output=ef69eab155be39f6 input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/clinic/decompressor.c.h b/Modules/_zstd/clinic/decompressor.c.h new file mode 100644 index 00000000000000..9359c637203f8f --- /dev/null +++ b/Modules/_zstd/clinic/decompressor.c.h @@ -0,0 +1,230 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) +# include "pycore_gc.h" // PyGC_Head +# include "pycore_runtime.h" // _Py_ID() +#endif +#include "pycore_abstract.h" // _PyNumber_Index() +#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION() +#include "pycore_modsupport.h" // _PyArg_UnpackKeywords() + +PyDoc_STRVAR(_zstd_ZstdDecompressor___init____doc__, +"ZstdDecompressor(zstd_dict=None, options=None)\n" +"--\n" +"\n" +"Create a decompressor object for decompressing data incrementally.\n" +"\n" +" zstd_dict\n" +" A ZstdDict object, a pre-trained zstd dictionary.\n" +" options\n" +" A dict object that contains advanced decompression parameters.\n" +"\n" +"Thread-safe at method level. 
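Because FLUSH_FRAME ends an independent frame, the same object can go on producing further frames, as the flush() docstring notes; a sketch:

    from compression.zstd import ZstdCompressor

    comp = ZstdCompressor()
    frame1 = comp.compress(b"payload one") + comp.flush()    # default mode: FLUSH_FRAME
    frame2 = comp.compress(b"payload two") + comp.flush()    # compressor is still usable
    stream = frame1 + frame2                                 # two independent frames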
For one-shot decompression, use the decompress()\n" +"function instead."); + +static int +_zstd_ZstdDecompressor___init___impl(ZstdDecompressor *self, + PyObject *zstd_dict, PyObject *options); + +static int +_zstd_ZstdDecompressor___init__(PyObject *self, PyObject *args, PyObject *kwargs) +{ + int return_value = -1; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(zstd_dict), &_Py_ID(options), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"zstd_dict", "options", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "ZstdDecompressor", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + PyObject * const *fastargs; + Py_ssize_t nargs = PyTuple_GET_SIZE(args); + Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 0; + PyObject *zstd_dict = Py_None; + PyObject *options = Py_None; + + fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, + /*minpos*/ 0, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!fastargs) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + if (fastargs[0]) { + zstd_dict = fastargs[0]; + if (!--noptargs) { + goto skip_optional_pos; + } + } + options = fastargs[1]; +skip_optional_pos: + return_value = _zstd_ZstdDecompressor___init___impl((ZstdDecompressor *)self, zstd_dict, options); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdDecompressor_unused_data__doc__, +"A bytes object of un-consumed input data.\n" +"\n" +"When ZstdDecompressor object stops after a frame is\n" +"decompressed, unused input data after the frame. 
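A round-trip sketch for the ZstdDecompressor constructor documented above (import path assumed):

    from compression.zstd import ZstdCompressor, ZstdDecompressor

    comp = ZstdCompressor()
    frame = comp.compress(b"hello zstd") + comp.flush()

    decomp = ZstdDecompressor()            # optionally zstd_dict=... and/or options=...
    assert decomp.decompress(frame) == b"hello zstd"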
Otherwise this will be b\'\'."); +#if defined(_zstd_ZstdDecompressor_unused_data_DOCSTR) +# undef _zstd_ZstdDecompressor_unused_data_DOCSTR +#endif +#define _zstd_ZstdDecompressor_unused_data_DOCSTR _zstd_ZstdDecompressor_unused_data__doc__ + +#if !defined(_zstd_ZstdDecompressor_unused_data_DOCSTR) +# define _zstd_ZstdDecompressor_unused_data_DOCSTR NULL +#endif +#if defined(_ZSTD_ZSTDDECOMPRESSOR_UNUSED_DATA_GETSETDEF) +# undef _ZSTD_ZSTDDECOMPRESSOR_UNUSED_DATA_GETSETDEF +# define _ZSTD_ZSTDDECOMPRESSOR_UNUSED_DATA_GETSETDEF {"unused_data", (getter)_zstd_ZstdDecompressor_unused_data_get, (setter)_zstd_ZstdDecompressor_unused_data_set, _zstd_ZstdDecompressor_unused_data_DOCSTR}, +#else +# define _ZSTD_ZSTDDECOMPRESSOR_UNUSED_DATA_GETSETDEF {"unused_data", (getter)_zstd_ZstdDecompressor_unused_data_get, NULL, _zstd_ZstdDecompressor_unused_data_DOCSTR}, +#endif + +static PyObject * +_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self); + +static PyObject * +_zstd_ZstdDecompressor_unused_data_get(PyObject *self, void *Py_UNUSED(context)) +{ + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self); + Py_END_CRITICAL_SECTION(); + + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdDecompressor_decompress__doc__, +"decompress($self, /, data, max_length=-1)\n" +"--\n" +"\n" +"Decompress *data*, returning uncompressed bytes if possible, or b\'\' otherwise.\n" +"\n" +" data\n" +" A bytes-like object, zstd data to be decompressed.\n" +" max_length\n" +" Maximum size of returned data. When it is negative, the size of\n" +" output buffer is unlimited. When it is nonnegative, returns at\n" +" most max_length bytes of decompressed data.\n" +"\n" +"If *max_length* is nonnegative, returns at most *max_length* bytes of\n" +"decompressed data. If this limit is reached and further output can be\n" +"produced, *self.needs_input* will be set to ``False``. In this case, the next\n" +"call to *decompress()* may provide *data* as b\'\' to obtain more of the output.\n" +"\n" +"If all of the input data was decompressed and returned (either because this\n" +"was less than *max_length* bytes, or because *max_length* was negative),\n" +"*self.needs_input* will be set to True.\n" +"\n" +"Attempting to decompress data after the end of a frame is reached raises an\n" +"EOFError. 
Any data found after the end of the frame is ignored and saved in\n" +"the self.unused_data attribute."); + +#define _ZSTD_ZSTDDECOMPRESSOR_DECOMPRESS_METHODDEF \ + {"decompress", _PyCFunction_CAST(_zstd_ZstdDecompressor_decompress), METH_FASTCALL|METH_KEYWORDS, _zstd_ZstdDecompressor_decompress__doc__}, + +static PyObject * +_zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self, + Py_buffer *data, + Py_ssize_t max_length); + +static PyObject * +_zstd_ZstdDecompressor_decompress(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(data), &_Py_ID(max_length), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"data", "max_length", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "decompress", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 1; + Py_buffer data = {NULL, NULL}; + Py_ssize_t max_length = -1; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 1, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (PyObject_GetBuffer(args[0], &data, PyBUF_SIMPLE) != 0) { + goto exit; + } + if (!noptargs) { + goto skip_optional_pos; + } + { + Py_ssize_t ival = -1; + PyObject *iobj = _PyNumber_Index(args[1]); + if (iobj != NULL) { + ival = PyLong_AsSsize_t(iobj); + Py_DECREF(iobj); + } + if (ival == -1 && PyErr_Occurred()) { + goto exit; + } + max_length = ival; + } +skip_optional_pos: + return_value = _zstd_ZstdDecompressor_decompress_impl((ZstdDecompressor *)self, &data, max_length); + +exit: + /* Cleanup for data */ + if (data.obj) { + PyBuffer_Release(&data); + } + + return return_value; +} +/*[clinic end generated code: output=ae703f0465a2906d input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/clinic/zdict.c.h b/Modules/_zstd/clinic/zdict.c.h new file mode 100644 index 00000000000000..4e0f7b64172a74 --- /dev/null +++ b/Modules/_zstd/clinic/zdict.c.h @@ -0,0 +1,207 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) +# include "pycore_gc.h" // PyGC_Head +# include "pycore_runtime.h" // _Py_ID() +#endif +#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION() +#include "pycore_modsupport.h" // _PyArg_UnpackKeywords() + +PyDoc_STRVAR(_zstd_ZstdDict___init____doc__, +"ZstdDict(dict_content, is_raw=False)\n" +"--\n" +"\n" +"Represents a zstd dictionary, which can be used for compression/decompression.\n" +"\n" +" dict_content\n" +" A bytes-like object, dictionary\'s content.\n" +" is_raw\n" +" This parameter is for advanced user. True means dict_content\n" +" argument is a \"raw content\" dictionary, free of any format\n" +" restriction. 
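The max_length / needs_input / unused_data protocol described in the docstring mirrors the other incremental decompressors in the stdlib; a hedged sketch of a bounded read followed by a drain:

    from compression.zstd import ZstdCompressor, ZstdDecompressor

    comp = ZstdCompressor()
    frame = comp.compress(b"x" * 10_000) + comp.flush()

    decomp = ZstdDecompressor()
    first = decomp.decompress(frame + b"!trailing", max_length=5)
    assert first == b"xxxxx" and not decomp.needs_input   # output limit reached
    rest = decomp.decompress(b"")                         # feed b'' to obtain more output
    assert first + rest == b"x" * 10_000
    assert decomp.unused_data == b"!trailing"             # bytes after the frame are kept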
False means dict_content argument is an ordinary\n" +" zstd dictionary, was created by zstd functions, follow a\n" +" specified format.\n" +"\n" +"It\'s thread-safe, and can be shared by multiple ZstdCompressor /\n" +"ZstdDecompressor objects."); + +static int +_zstd_ZstdDict___init___impl(ZstdDict *self, PyObject *dict_content, + int is_raw); + +static int +_zstd_ZstdDict___init__(PyObject *self, PyObject *args, PyObject *kwargs) +{ + int return_value = -1; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(dict_content), &_Py_ID(is_raw), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"dict_content", "is_raw", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "ZstdDict", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + PyObject * const *fastargs; + Py_ssize_t nargs = PyTuple_GET_SIZE(args); + Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1; + PyObject *dict_content; + int is_raw = 0; + + fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, + /*minpos*/ 1, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!fastargs) { + goto exit; + } + dict_content = fastargs[0]; + if (!noptargs) { + goto skip_optional_pos; + } + is_raw = PyObject_IsTrue(fastargs[1]); + if (is_raw < 0) { + goto exit; + } +skip_optional_pos: + return_value = _zstd_ZstdDict___init___impl((ZstdDict *)self, dict_content, is_raw); + +exit: + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdDict_as_digested_dict__doc__, +"Load as a digested dictionary to compressor.\n" +"\n" +"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digested_dict)\n" +"1. Some advanced compression parameters of compressor may be overridden\n" +" by parameters of digested dictionary.\n" +"2. ZstdDict has a digested dictionaries cache for each compression level.\n" +" It\'s faster when loading again a digested dictionary with the same\n" +" compression level.\n" +"3. 
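A sketch of creating and using a ZstdDict; a trained dictionary is the normal input, but a raw-content dictionary (is_raw=True) can be any shared byte sequence, so one is fabricated inline here:

    from compression.zstd import ZstdCompressor, ZstdDecompressor, ZstdDict

    zd = ZstdDict(b"shared sample bytes " * 32, is_raw=True)

    comp = ZstdCompressor(zstd_dict=zd)
    frame = comp.compress(b"sample payload") + comp.flush()
    assert ZstdDecompressor(zstd_dict=zd).decompress(frame) == b"sample payload"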
No need to use this for decompression."); +#if defined(_zstd_ZstdDict_as_digested_dict_DOCSTR) +# undef _zstd_ZstdDict_as_digested_dict_DOCSTR +#endif +#define _zstd_ZstdDict_as_digested_dict_DOCSTR _zstd_ZstdDict_as_digested_dict__doc__ + +#if !defined(_zstd_ZstdDict_as_digested_dict_DOCSTR) +# define _zstd_ZstdDict_as_digested_dict_DOCSTR NULL +#endif +#if defined(_ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF) +# undef _ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF +# define _ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF {"as_digested_dict", (getter)_zstd_ZstdDict_as_digested_dict_get, (setter)_zstd_ZstdDict_as_digested_dict_set, _zstd_ZstdDict_as_digested_dict_DOCSTR}, +#else +# define _ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF {"as_digested_dict", (getter)_zstd_ZstdDict_as_digested_dict_get, NULL, _zstd_ZstdDict_as_digested_dict_DOCSTR}, +#endif + +static PyObject * +_zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self); + +static PyObject * +_zstd_ZstdDict_as_digested_dict_get(PyObject *self, void *Py_UNUSED(context)) +{ + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self); + Py_END_CRITICAL_SECTION(); + + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdDict_as_undigested_dict__doc__, +"Load as an undigested dictionary to compressor.\n" +"\n" +"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undigested_dict)\n" +"1. The advanced compression parameters of compressor will not be overridden.\n" +"2. Loading an undigested dictionary is costly. If load an undigested dictionary\n" +" multiple times, consider reusing a compressor object.\n" +"3. No need to use this for decompression."); +#if defined(_zstd_ZstdDict_as_undigested_dict_DOCSTR) +# undef _zstd_ZstdDict_as_undigested_dict_DOCSTR +#endif +#define _zstd_ZstdDict_as_undigested_dict_DOCSTR _zstd_ZstdDict_as_undigested_dict__doc__ + +#if !defined(_zstd_ZstdDict_as_undigested_dict_DOCSTR) +# define _zstd_ZstdDict_as_undigested_dict_DOCSTR NULL +#endif +#if defined(_ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF) +# undef _ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF +# define _ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF {"as_undigested_dict", (getter)_zstd_ZstdDict_as_undigested_dict_get, (setter)_zstd_ZstdDict_as_undigested_dict_set, _zstd_ZstdDict_as_undigested_dict_DOCSTR}, +#else +# define _ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF {"as_undigested_dict", (getter)_zstd_ZstdDict_as_undigested_dict_get, NULL, _zstd_ZstdDict_as_undigested_dict_DOCSTR}, +#endif + +static PyObject * +_zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self); + +static PyObject * +_zstd_ZstdDict_as_undigested_dict_get(PyObject *self, void *Py_UNUSED(context)) +{ + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self); + Py_END_CRITICAL_SECTION(); + + return return_value; +} + +PyDoc_STRVAR(_zstd_ZstdDict_as_prefix__doc__, +"Load as a prefix to compressor/decompressor.\n" +"\n" +"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)\n" +"1. Prefix is compatible with long distance matching, while dictionary is not.\n" +"2. It only works for the first frame, then the compressor/decompressor will\n" +" return to no prefix state.\n" +"3. 
When decompressing, must use the same prefix as when compressing.\""); +#if defined(_zstd_ZstdDict_as_prefix_DOCSTR) +# undef _zstd_ZstdDict_as_prefix_DOCSTR +#endif +#define _zstd_ZstdDict_as_prefix_DOCSTR _zstd_ZstdDict_as_prefix__doc__ + +#if !defined(_zstd_ZstdDict_as_prefix_DOCSTR) +# define _zstd_ZstdDict_as_prefix_DOCSTR NULL +#endif +#if defined(_ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF) +# undef _ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF +# define _ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF {"as_prefix", (getter)_zstd_ZstdDict_as_prefix_get, (setter)_zstd_ZstdDict_as_prefix_set, _zstd_ZstdDict_as_prefix_DOCSTR}, +#else +# define _ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF {"as_prefix", (getter)_zstd_ZstdDict_as_prefix_get, NULL, _zstd_ZstdDict_as_prefix_DOCSTR}, +#endif + +static PyObject * +_zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self); + +static PyObject * +_zstd_ZstdDict_as_prefix_get(PyObject *self, void *Py_UNUSED(context)) +{ + PyObject *return_value = NULL; + + Py_BEGIN_CRITICAL_SECTION(self); + return_value = _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self); + Py_END_CRITICAL_SECTION(); + + return return_value; +} +/*[clinic end generated code: output=59257c053f74eda7 input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c new file mode 100644 index 00000000000000..d0f677be821572 --- /dev/null +++ b/Modules/_zstd/compressor.c @@ -0,0 +1,707 @@ +/* +Low level interface to Meta's zstd library for use in the compression.zstd +Python module. +*/ + +/* ZstdCompressor class definitions */ + +/*[clinic input] +module _zstd +class _zstd.ZstdCompressor "ZstdCompressor *" "clinic_state()->ZstdCompressor_type" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=875bf614798f80cb]*/ + + +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + +#include "_zstdmodule.h" + +#include "buffer.h" + +#include // offsetof() + + +#define ZstdCompressor_CAST(op) ((ZstdCompressor *)op) + +int +_PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, + const char *arg_name, const char* arg_type) +{ + size_t zstd_ret; + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state == NULL) { + return -1; + } + + /* Integer compression level */ + if (PyLong_Check(level_or_options)) { + int level = PyLong_AsInt(level_or_options); + if (level == -1 && PyErr_Occurred()) { + PyErr_Format(PyExc_ValueError, + "Compression level should be an int value between %d and %d.", + ZSTD_minCLevel(), ZSTD_maxCLevel()); + return -1; + } + + /* Save for generating ZSTD_CDICT */ + self->compression_level = level; + + /* Set compressionLevel to compression context */ + zstd_ret = ZSTD_CCtx_setParameter(self->cctx, + ZSTD_c_compressionLevel, + level); + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret); + return -1; + } + return 0; + } + + /* Options dict */ + if (PyDict_Check(level_or_options)) { + PyObject *key, *value; + Py_ssize_t pos = 0; + + while (PyDict_Next(level_or_options, &pos, &key, &value)) { + /* Check key type */ + if (Py_TYPE(key) == mod_state->DParameter_type) { + PyErr_SetString(PyExc_TypeError, + "Key of compression option dict should " + "NOT be DParameter."); + return -1; + } + + int key_v = PyLong_AsInt(key); + if (key_v == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "Key of options dict should be a CParameter attribute."); + return -1; + } + + // TODO(emmatyping): check bounds when there is a value error here 
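The three attributes above select how the dictionary is handed to the (de)compression context; a sketch following the `compress(dat, zstd_dict=zd.as_digested_dict)` pattern the docstrings give:

    from compression.zstd import ZstdCompressor, ZstdDecompressor, ZstdDict

    zd = ZstdDict(b"0123456789abcdef" * 64, is_raw=True)

    c_digested = ZstdCompressor(zstd_dict=zd.as_digested_dict)      # fast to reload, may override params
    c_undigested = ZstdCompressor(zstd_dict=zd.as_undigested_dict)  # keeps advanced parameters
    comp = ZstdCompressor(zstd_dict=zd.as_prefix)                   # prefix: first frame only
    frame = comp.compress(b"data") + comp.flush()
    # The same prefix must be used on the decompression side.
    assert ZstdDecompressor(zstd_dict=zd.as_prefix).decompress(frame) == b"data"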
for better + // error message? + int value_v = PyLong_AsInt(value); + if (value_v == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "Value of option dict should be an int."); + return -1; + } + + if (key_v == ZSTD_c_compressionLevel) { + /* Save for generating ZSTD_CDICT */ + self->compression_level = value_v; + } + else if (key_v == ZSTD_c_nbWorkers) { + /* From zstd library doc: + 1. When nbWorkers >= 1, triggers asynchronous mode when + used with ZSTD_compressStream2(). + 2, Default value is `0`, aka "single-threaded mode" : no + worker is spawned, compression is performed inside + caller's thread, all invocations are blocking. */ + if (value_v != 0) { + self->use_multithread = 1; + } + } + + /* Set parameter to compression context */ + zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v); + if (ZSTD_isError(zstd_ret)) { + set_parameter_error(mod_state, 1, key_v, value_v); + return -1; + } + } + return 0; + } + PyErr_Format(PyExc_TypeError, "Invalid type for %s. Expected %s", arg_name, arg_type); + return -1; +} + +static void +capsule_free_cdict(PyObject *capsule) +{ + ZSTD_CDict *cdict = PyCapsule_GetPointer(capsule, NULL); + ZSTD_freeCDict(cdict); +} + +ZSTD_CDict * +_get_CDict(ZstdDict *self, int compressionLevel) +{ + PyObject *level = NULL; + PyObject *capsule; + ZSTD_CDict *cdict; + + // TODO(emmatyping): refactor critical section code into a lock_held function + Py_BEGIN_CRITICAL_SECTION(self); + + /* int level object */ + level = PyLong_FromLong(compressionLevel); + if (level == NULL) { + goto error; + } + + /* Get PyCapsule object from self->c_dicts */ + capsule = PyDict_GetItemWithError(self->c_dicts, level); + if (capsule == NULL) { + if (PyErr_Occurred()) { + goto error; + } + + /* Create ZSTD_CDict instance */ + char *dict_buffer = PyBytes_AS_STRING(self->dict_content); + Py_ssize_t dict_len = Py_SIZE(self->dict_content); + Py_BEGIN_ALLOW_THREADS + cdict = ZSTD_createCDict(dict_buffer, + dict_len, + compressionLevel); + Py_END_ALLOW_THREADS + + if (cdict == NULL) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Failed to create ZSTD_CDict instance from zstd " + "dictionary content. Maybe the content is corrupted."); + } + goto error; + } + + /* Put ZSTD_CDict instance into PyCapsule object */ + capsule = PyCapsule_New(cdict, NULL, capsule_free_cdict); + if (capsule == NULL) { + ZSTD_freeCDict(cdict); + goto error; + } + + /* Add PyCapsule object to self->c_dicts */ + if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) { + Py_DECREF(capsule); + goto error; + } + Py_DECREF(capsule); + } + else { + /* ZSTD_CDict instance already exists */ + cdict = PyCapsule_GetPointer(capsule, NULL); + } + goto success; + +error: + cdict = NULL; +success: + Py_XDECREF(level); + Py_END_CRITICAL_SECTION(); + return cdict; +} + +int +_PyZstd_load_c_dict(ZstdCompressor *self, PyObject *dict) { + + size_t zstd_ret; + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state == NULL) { + return -1; + } + ZstdDict *zd; + int type, ret; + + /* Check ZstdDict */ + ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); + if (ret < 0) { + return -1; + } + else if (ret > 0) { + /* When compressing, use undigested dictionary by default. 
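_get_CDict() keeps one digested dictionary per compression level in the ZstdDict's c_dicts mapping (wrapped in PyCapsules), creating each entry on first use. A purely illustrative Python rendering of that create-once/reuse shape, not part of the module's API:

    # Illustrative only: the caching pattern of _get_CDict() in Python terms.
    class DigestedDictCache:
        def __init__(self, dict_content: bytes):
            self.dict_content = dict_content
            self._c_dicts = {}                     # compression level -> digested form

        def get(self, level: int):
            digested = self._c_dicts.get(level)
            if digested is None:
                # Stand-in for ZSTD_createCDict(dict_content, len, level).
                digested = (self.dict_content, level)
                self._c_dicts[level] = digested
            return digested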
*/ + zd = (ZstdDict*)dict; + type = DICT_TYPE_UNDIGESTED; + goto load; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { + /* Check ZstdDict */ + ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), + (PyObject*)mod_state->ZstdDict_type); + if (ret < 0) { + return -1; + } + else if (ret > 0) { + /* type == -1 may indicate an error. */ + type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == DICT_TYPE_DIGESTED || + type == DICT_TYPE_UNDIGESTED || + type == DICT_TYPE_PREFIX) + { + assert(type >= 0); + zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + goto load; + } + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be ZstdDict object."); + return -1; + +load: + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_CDict */ + ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level); + if (c_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary. + It overrides some compression context's parameters. */ + Py_BEGIN_CRITICAL_SECTION(self); + zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict); + Py_END_CRITICAL_SECTION(); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary. + It doesn't override compression context's parameters. */ + Py_BEGIN_CRITICAL_SECTION2(self, zd); + zstd_ret = ZSTD_CCtx_loadDictionary( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + Py_END_CRITICAL_SECTION2(); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + Py_BEGIN_CRITICAL_SECTION2(self, zd); + zstd_ret = ZSTD_CCtx_refPrefix( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + Py_END_CRITICAL_SECTION2(); + } + else { + Py_UNREACHABLE(); + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret); + return -1; + } + return 0; +} + +#define clinic_state() (get_zstd_state_from_type(type)) +#include "clinic/compressor.c.h" +#undef clinic_state + +static PyObject * +_zstd_ZstdCompressor_new(PyTypeObject *type, PyObject *Py_UNUSED(args), PyObject *Py_UNUSED(kwargs)) +{ + ZstdCompressor *self; + self = PyObject_GC_New(ZstdCompressor, type); + if (self == NULL) { + goto error; + } + + self->inited = 0; + self->dict = NULL; + self->use_multithread = 0; + + + /* Compression context */ + self->cctx = ZSTD_createCCtx(); + if (self->cctx == NULL) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Unable to create ZSTD_CCtx instance."); + } + goto error; + } + + /* Last mode */ + self->last_mode = ZSTD_e_end; + + return (PyObject*)self; + +error: + if (self != NULL) { + PyObject_GC_Del(self); + } + return NULL; +} + +static void +ZstdCompressor_dealloc(PyObject *ob) +{ + ZstdCompressor *self = ZstdCompressor_CAST(ob); + + PyObject_GC_UnTrack(self); + + /* Free compression context */ + ZSTD_freeCCtx(self->cctx); + + /* Py_XDECREF the dict after free the compression context */ + Py_CLEAR(self->dict); + + PyTypeObject *tp = Py_TYPE(self); + PyObject_GC_Del(ob); + Py_DECREF(tp); +} + +/*[clinic input] +_zstd.ZstdCompressor.__init__ + + level: object = None + The compression level to use, defaults to ZSTD_CLEVEL_DEFAULT. + options: object = None + A dict object that contains advanced compression parameters. + zstd_dict: object = None + A ZstdDict object, a pre-trained zstd dictionary. + +Create a compressor object for compressing data incrementally. + +Thread-safe at method level. 
For one-shot compression, use the compress() +function instead. +[clinic start generated code]*/ + +static int +_zstd_ZstdCompressor___init___impl(ZstdCompressor *self, PyObject *level, + PyObject *options, PyObject *zstd_dict) +/*[clinic end generated code: output=215e6c4342732f96 input=9f79b0d8d34c8ef0]*/ +{ + /* Only called once */ + if (self->inited) { + PyErr_SetString(PyExc_RuntimeError, init_twice_msg); + return -1; + } + self->inited = 1; + + if (level != Py_None && options != Py_None) { + PyErr_SetString(PyExc_RuntimeError, "Only one of level or options should be used."); + return -1; + } + + /* Set compressLevel/options to compression context */ + if (level != Py_None) { + if (_PyZstd_set_c_parameters(self, level, "level", "int") < 0) { + return -1; + } + } + + if (options != Py_None) { + if (_PyZstd_set_c_parameters(self, options, "options", "dict") < 0) { + return -1; + } + } + + /* Load dictionary to compression context */ + if (zstd_dict != Py_None) { + if (_PyZstd_load_c_dict(self, zstd_dict) < 0) { + return -1; + } + + /* Py_INCREF the dict */ + Py_INCREF(zstd_dict); + self->dict = zstd_dict; + } + + // We can only start tracking self with the GC once self->dict is set. + PyObject_GC_Track(self); + return 0; +} + +PyObject * +compress_impl(ZstdCompressor *self, Py_buffer *data, + ZSTD_EndDirective end_directive) +{ + ZSTD_inBuffer in; + ZSTD_outBuffer out; + _BlocksOutputBuffer buffer = {.list = NULL}; + size_t zstd_ret; + PyObject *ret; + + /* Prepare input & output buffers */ + if (data != NULL) { + in.src = data->buf; + in.size = data->len; + in.pos = 0; + } + else { + in.src = ∈ + in.size = 0; + in.pos = 0; + } + + /* Calculate output buffer's size */ + size_t output_buffer_size = ZSTD_compressBound(in.size); + if (output_buffer_size > (size_t) PY_SSIZE_T_MAX) { + PyErr_NoMemory(); + goto error; + } + + if (_OutputBuffer_InitWithSize(&buffer, &out, -1, + (Py_ssize_t) output_buffer_size) < 0) { + goto error; + } + + + /* zstd stream compress */ + while (1) { + Py_BEGIN_ALLOW_THREADS + zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in, end_directive); + Py_END_ALLOW_THREADS + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); + } + goto error; + } + + /* Finished */ + if (zstd_ret == 0) { + break; + } + + /* Output buffer should be exhausted, grow the buffer. 
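The initializer enforces that level and options are alternatives; a small sketch of the failure mode (this revision raises RuntimeError, so a broad except is used):

    from compression.zstd import ZstdCompressor

    try:
        ZstdCompressor(level=3, options={})        # level and options are mutually exclusive
    except Exception as exc:
        print(f"{type(exc).__name__}: {exc}")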
*/ + assert(out.pos == out.size); + if (out.pos == out.size) { + if (_OutputBuffer_Grow(&buffer, &out) < 0) { + goto error; + } + } + } + + /* Return a bytes object */ + ret = _OutputBuffer_Finish(&buffer, &out); + if (ret != NULL) { + return ret; + } + +error: + _OutputBuffer_OnError(&buffer); + return NULL; +} + +static PyObject * +compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data) +{ + ZSTD_inBuffer in; + ZSTD_outBuffer out; + _BlocksOutputBuffer buffer = {.list = NULL}; + size_t zstd_ret; + PyObject *ret; + + /* Prepare input & output buffers */ + in.src = data->buf; + in.size = data->len; + in.pos = 0; + + if (_OutputBuffer_InitAndGrow(&buffer, &out, -1) < 0) { + goto error; + } + + /* zstd stream compress */ + while (1) { + Py_BEGIN_ALLOW_THREADS + do { + zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in, ZSTD_e_continue); + } while (out.pos != out.size && in.pos != in.size && !ZSTD_isError(zstd_ret)); + Py_END_ALLOW_THREADS + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); + } + goto error; + } + + /* Like compress_impl(), output as much as possible. */ + if (out.pos == out.size) { + if (_OutputBuffer_Grow(&buffer, &out) < 0) { + goto error; + } + } + else if (in.pos == in.size) { + /* Finished */ + assert(mt_continue_should_break(&in, &out)); + break; + } + } + + /* Return a bytes object */ + ret = _OutputBuffer_Finish(&buffer, &out); + if (ret != NULL) { + return ret; + } + +error: + _OutputBuffer_OnError(&buffer); + return NULL; +} + +/*[clinic input] +_zstd.ZstdCompressor.compress + + data: Py_buffer + mode: int(c_default="ZSTD_e_continue") = ZstdCompressor.CONTINUE + Can be these 3 values ZstdCompressor.CONTINUE, + ZstdCompressor.FLUSH_BLOCK, ZstdCompressor.FLUSH_FRAME + +Provide data to the compressor object. + +Return a chunk of compressed data if possible, or b'' otherwise. When you have +finished providing data to the compressor, call the flush() method to finish +the compression process. +[clinic start generated code]*/ + +static PyObject * +_zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, + int mode) +/*[clinic end generated code: output=ed7982d1cf7b4f98 input=ac2c21d180f579ea]*/ +{ + PyObject *ret; + + /* Check mode value */ + if (mode != ZSTD_e_continue && + mode != ZSTD_e_flush && + mode != ZSTD_e_end) + { + PyErr_SetString(PyExc_ValueError, + "mode argument wrong value, it should be one of " + "ZstdCompressor.CONTINUE, ZstdCompressor.FLUSH_BLOCK, " + "ZstdCompressor.FLUSH_FRAME."); + return NULL; + } + + /* Thread-safe code */ + Py_BEGIN_CRITICAL_SECTION(self); + + /* Compress */ + if (self->use_multithread && mode == ZSTD_e_continue) { + ret = compress_mt_continue_impl(self, data); + } + else { + ret = compress_impl(self, data, mode); + } + + if (ret) { + self->last_mode = mode; + } + else { + self->last_mode = ZSTD_e_end; + + /* Resetting cctx's session never fail */ + ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); + } + Py_END_CRITICAL_SECTION(); + + return ret; +} + +/*[clinic input] +_zstd.ZstdCompressor.flush + + mode: int(c_default="ZSTD_e_end") = ZstdCompressor.FLUSH_FRAME + Can be these 2 values ZstdCompressor.FLUSH_FRAME, + ZstdCompressor.FLUSH_BLOCK + +Finish the compression process. + +Flush any remaining data left in internal buffers. 
Since zstd data consists +of one or more independent frames, the compressor object can still be used +after this method is called. +[clinic start generated code]*/ + +static PyObject * +_zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) +/*[clinic end generated code: output=b7cf2c8d64dcf2e3 input=a766870301932b85]*/ +{ + PyObject *ret; + + /* Check mode value */ + if (mode != ZSTD_e_end && mode != ZSTD_e_flush) { + PyErr_SetString(PyExc_ValueError, + "mode argument wrong value, it should be " + "ZstdCompressor.FLUSH_FRAME or " + "ZstdCompressor.FLUSH_BLOCK."); + return NULL; + } + + /* Thread-safe code */ + Py_BEGIN_CRITICAL_SECTION(self); + ret = compress_impl(self, NULL, mode); + + if (ret) { + self->last_mode = mode; + } + else { + self->last_mode = ZSTD_e_end; + + /* Resetting cctx's session never fail */ + ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); + } + Py_END_CRITICAL_SECTION(); + + return ret; +} + +static PyMethodDef ZstdCompressor_methods[] = { + _ZSTD_ZSTDCOMPRESSOR_COMPRESS_METHODDEF + _ZSTD_ZSTDCOMPRESSOR_FLUSH_METHODDEF + + {0} +}; + +PyDoc_STRVAR(ZstdCompressor_last_mode_doc, +"The last mode used to this compressor object, its value can be .CONTINUE,\n" +".FLUSH_BLOCK, .FLUSH_FRAME. Initialized to .FLUSH_FRAME.\n\n" +"It can be used to get the current state of a compressor, such as, data flushed,\n" +"a frame ended."); + +static PyMemberDef ZstdCompressor_members[] = { + {"last_mode", Py_T_INT, offsetof(ZstdCompressor, last_mode), + Py_READONLY, ZstdCompressor_last_mode_doc}, + {0} +}; + +static int +ZstdCompressor_traverse(PyObject *ob, visitproc visit, void *arg) +{ + ZstdCompressor *self = ZstdCompressor_CAST(ob); + Py_VISIT(self->dict); + return 0; +} + +static int +ZstdCompressor_clear(PyObject *ob) +{ + ZstdCompressor *self = ZstdCompressor_CAST(ob); + Py_CLEAR(self->dict); + return 0; +} + +static PyType_Slot zstdcompressor_slots[] = { + {Py_tp_new, _zstd_ZstdCompressor_new}, + {Py_tp_dealloc, ZstdCompressor_dealloc}, + {Py_tp_init, _zstd_ZstdCompressor___init__}, + {Py_tp_methods, ZstdCompressor_methods}, + {Py_tp_members, ZstdCompressor_members}, + {Py_tp_doc, (char*)_zstd_ZstdCompressor___init____doc__}, + {Py_tp_traverse, ZstdCompressor_traverse}, + {Py_tp_clear, ZstdCompressor_clear}, + {0} +}; + +PyType_Spec zstdcompressor_type_spec = { + .name = "_zstd.ZstdCompressor", + .basicsize = sizeof(ZstdCompressor), + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .slots = zstdcompressor_slots, +}; diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c new file mode 100644 index 00000000000000..4e3a28068be130 --- /dev/null +++ b/Modules/_zstd/decompressor.c @@ -0,0 +1,891 @@ +/* +Low level interface to Meta's zstd library for use in the compression.zstd +Python module. 
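The read-only last_mode member documented above reflects the most recent compress()/flush() call; a sketch:

    from compression.zstd import ZstdCompressor

    comp = ZstdCompressor()
    assert comp.last_mode == ZstdCompressor.FLUSH_FRAME    # initial value
    comp.compress(b"abc")
    assert comp.last_mode == ZstdCompressor.CONTINUE
    comp.flush(ZstdCompressor.FLUSH_BLOCK)
    assert comp.last_mode == ZstdCompressor.FLUSH_BLOCK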
+*/ + +/* ZstdDecompressor class definition */ + +/*[clinic input] +module _zstd +class _zstd.ZstdDecompressor "ZstdDecompressor *" "clinic_state()->ZstdDecompressor_type" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=4e6eae327c0c0c76]*/ + +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + +#include "_zstdmodule.h" + +#include "buffer.h" + +#include // offsetof() + +#define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op) + +static inline ZSTD_DDict * +_get_DDict(ZstdDict *self) +{ + ZSTD_DDict *ret; + + /* Already created */ + if (self->d_dict != NULL) { + return self->d_dict; + } + + Py_BEGIN_CRITICAL_SECTION(self); + if (self->d_dict == NULL) { + /* Create ZSTD_DDict instance from dictionary content */ + char *dict_buffer = PyBytes_AS_STRING(self->dict_content); + Py_ssize_t dict_len = Py_SIZE(self->dict_content); + Py_BEGIN_ALLOW_THREADS + self->d_dict = ZSTD_createDDict(dict_buffer, + dict_len); + Py_END_ALLOW_THREADS + + if (self->d_dict == NULL) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Failed to create ZSTD_DDict instance from zstd " + "dictionary content. Maybe the content is corrupted."); + } + } + } + + /* Don't lose any exception */ + ret = self->d_dict; + Py_END_CRITICAL_SECTION(); + + return ret; +} + +/* Set decompression parameters to decompression context */ +int +_PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) +{ + size_t zstd_ret; + PyObject *key, *value; + Py_ssize_t pos; + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state == NULL) { + return -1; + } + + if (!PyDict_Check(options)) { + PyErr_SetString(PyExc_TypeError, + "options argument should be dict object."); + return -1; + } + + pos = 0; + while (PyDict_Next(options, &pos, &key, &value)) { + /* Check key type */ + if (Py_TYPE(key) == mod_state->CParameter_type) { + PyErr_SetString(PyExc_TypeError, + "Key of decompression options dict should " + "NOT be CParameter."); + return -1; + } + + /* Both key & value should be 32-bit signed int */ + int key_v = PyLong_AsInt(key); + if (key_v == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "Key of options dict should be a DParameter attribute."); + return -1; + } + + // TODO(emmatyping): check bounds when there is a value error here for better + // error message? + int value_v = PyLong_AsInt(value); + if (value_v == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "Value of options dict should be an int."); + return -1; + } + + /* Set parameter to compression context */ + Py_BEGIN_CRITICAL_SECTION(self); + zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v); + Py_END_CRITICAL_SECTION(); + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_parameter_error(mod_state, 0, key_v, value_v); + return -1; + } + } + return 0; +} + +/* Load dictionary or prefix to decompression context */ +int +_PyZstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) +{ + size_t zstd_ret; + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state == NULL) { + return -1; + } + ZstdDict *zd; + int type, ret; + + /* Check ZstdDict */ + ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); + if (ret < 0) { + return -1; + } + else if (ret > 0) { + /* When decompressing, use digested dictionary by default. 
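As on the compression side, decompression options are given as a dict keyed by the registered DParameter enum; a hedged sketch where both the re-export of DParameter and the window_log_max member name are assumptions, not taken from this diff:

    from compression.zstd import ZstdDecompressor, DParameter   # DParameter re-export assumed

    # window_log_max is an assumed member name, used here for illustration only.
    decomp = ZstdDecompressor(options={DParameter.window_log_max: 27})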
*/ + zd = (ZstdDict*)dict; + type = DICT_TYPE_DIGESTED; + goto load; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { + /* Check ZstdDict */ + ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), + (PyObject*)mod_state->ZstdDict_type); + if (ret < 0) { + return -1; + } + else if (ret > 0) { + /* type == -1 may indicate an error. */ + type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == DICT_TYPE_DIGESTED || + type == DICT_TYPE_UNDIGESTED || + type == DICT_TYPE_PREFIX) + { + assert(type >= 0); + zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + goto load; + } + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be ZstdDict object."); + return -1; + +load: + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_DDict */ + ZSTD_DDict *d_dict = _get_DDict(zd); + if (d_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary */ + Py_BEGIN_CRITICAL_SECTION(self); + zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict); + Py_END_CRITICAL_SECTION(); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary */ + Py_BEGIN_CRITICAL_SECTION2(self, zd); + zstd_ret = ZSTD_DCtx_loadDictionary( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + Py_END_CRITICAL_SECTION2(); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + Py_BEGIN_CRITICAL_SECTION2(self, zd); + zstd_ret = ZSTD_DCtx_refPrefix( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + Py_END_CRITICAL_SECTION2(); + } + else { + /* Impossible code path */ + PyErr_SetString(PyExc_SystemError, + "load_d_dict() impossible code path"); + return -1; + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret); + return -1; + } + return 0; +} + + + +/* + Given the two types of decompressors (defined in _zstdmodule.h): + + typedef enum { + TYPE_DECOMPRESSOR, // , ZstdDecompressor class + TYPE_ENDLESS_DECOMPRESSOR, // , decompress() function + } decompress_type; + + Decompress implementation for , , pseudo code: + + initialize_output_buffer + while True: + decompress_data + set_object_flag # .eof for , .at_frame_edge for . + + if output_buffer_exhausted: + if output_buffer_reached_max_length: + finish + grow_output_buffer + elif input_buffer_exhausted: + finish + + ZSTD_decompressStream()'s size_t return value: + - 0 when a frame is completely decoded and fully flushed, zstd's internal + buffer has no data. + - An error code, which can be tested using ZSTD_isError(). + - Or any other value > 0, which means there is still some decoding or + flushing to do to complete current frame. + + Note, decompressing "an empty input" in any case will make it > 0. + + supports multiple frames, has an .at_frame_edge flag, it means both the + input and output streams are at a frame edge. The flag can be set by this + statement: + + .at_frame_edge = (zstd_ret == 0) ? 1 : 0 + + But if decompressing "an empty input" at "a frame edge", zstd_ret will be + non-zero, then .at_frame_edge will be wrongly set to false. To solve this + problem, two AFE checks are needed to ensure that: when at "a frame edge", + empty input will not be decompressed. + + // AFE check + if (self->at_frame_edge && in->pos == in->size) { + finish + } + + In , if .at_frame_edge is eventually set to true, but input stream has + unconsumed data (in->pos < in->size), then the outer function + stream_decompress() will set .at_frame_edge to false. 
In this case, + although the output stream is at a frame edge, for the caller, the input + stream is not at a frame edge, see below diagram. This behavior does not + affect the next AFE check, since (in->pos < in->size). + + input stream: --------------|--- + ^ + output stream: ====================| + ^ +*/ +PyObject * +decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, + Py_ssize_t max_length, + Py_ssize_t initial_size, + decompress_type type) +{ + size_t zstd_ret; + ZSTD_outBuffer out; + _BlocksOutputBuffer buffer = {.list = NULL}; + PyObject *ret; + + /* The first AFE check for setting .at_frame_edge flag */ + if (type == TYPE_ENDLESS_DECOMPRESSOR) { + if (self->at_frame_edge && in->pos == in->size) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state == NULL) { + return NULL; + } + ret = mod_state->empty_bytes; + Py_INCREF(ret); + return ret; + } + } + + /* Initialize the output buffer */ + if (initial_size >= 0) { + if (_OutputBuffer_InitWithSize(&buffer, &out, max_length, initial_size) < 0) { + goto error; + } + } + else { + if (_OutputBuffer_InitAndGrow(&buffer, &out, max_length) < 0) { + goto error; + } + } + assert(out.pos == 0); + + while (1) { + /* Decompress */ + Py_BEGIN_ALLOW_THREADS + zstd_ret = ZSTD_decompressStream(self->dctx, &out, in); + Py_END_ALLOW_THREADS + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); + } + goto error; + } + + /* Set .eof/.af_frame_edge flag */ + if (type == TYPE_DECOMPRESSOR) { + /* ZstdDecompressor class stops when a frame is decompressed */ + if (zstd_ret == 0) { + self->eof = 1; + break; + } + } + else if (type == TYPE_ENDLESS_DECOMPRESSOR) { + /* decompress() function supports multiple frames */ + self->at_frame_edge = (zstd_ret == 0) ? 1 : 0; + + /* The second AFE check for setting .at_frame_edge flag */ + if (self->at_frame_edge && in->pos == in->size) { + break; + } + } + + /* Need to check out before in. Maybe zstd's internal buffer still has + a few bytes can be output, grow the buffer and continue. 
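At the Python level, the two decompressor types described in the comment above behave differently at frame boundaries; a sketch assuming the module-level decompress() function takes the data positionally:

    from compression.zstd import ZstdCompressor, ZstdDecompressor, decompress

    comp = ZstdCompressor()
    one = comp.compress(b"first") + comp.flush()
    two = comp.compress(b"second") + comp.flush()

    # The "endless" path (module-level decompress()) walks every frame.
    assert decompress(one + two) == b"firstsecond"

    # A ZstdDecompressor instance stops at the end of the first frame and
    # keeps the remaining bytes in unused_data.
    d = ZstdDecompressor()
    assert d.decompress(one + two) == b"first"
    assert d.unused_data == two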
*/ + if (out.pos == out.size) { + /* Output buffer exhausted */ + + /* Output buffer reached max_length */ + if (_OutputBuffer_ReachedMaxLength(&buffer, &out)) { + break; + } + + /* Grow output buffer */ + if (_OutputBuffer_Grow(&buffer, &out) < 0) { + goto error; + } + assert(out.pos == 0); + + } + else if (in->pos == in->size) { + /* Finished */ + break; + } + } + + /* Return a bytes object */ + ret = _OutputBuffer_Finish(&buffer, &out); + if (ret != NULL) { + return ret; + } + +error: + _OutputBuffer_OnError(&buffer); + return NULL; +} + +void +decompressor_reset_session(ZstdDecompressor *self, + decompress_type type) +{ + // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here + // and ensure lock is always held + + /* Reset variables */ + self->in_begin = 0; + self->in_end = 0; + + if (type == TYPE_DECOMPRESSOR) { + Py_CLEAR(self->unused_data); + } + + /* Reset variables in one operation */ + self->needs_input = 1; + self->at_frame_edge = 1; + self->eof = 0; + self->_unused_char_for_align = 0; + + /* Resetting session never fail */ + ZSTD_DCtx_reset(self->dctx, ZSTD_reset_session_only); +} + +PyObject * +stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length, + decompress_type type) +{ + Py_ssize_t initial_buffer_size = -1; + ZSTD_inBuffer in; + PyObject *ret = NULL; + int use_input_buffer; + + if (type == TYPE_DECOMPRESSOR) { + /* Check .eof flag */ + if (self->eof) { + PyErr_SetString(PyExc_EOFError, "Already at the end of a zstd frame."); + assert(ret == NULL); + goto success; + } + } + else if (type == TYPE_ENDLESS_DECOMPRESSOR) { + /* Fast path for the first frame */ + if (self->at_frame_edge && self->in_begin == self->in_end) { + /* Read decompressed size */ + uint64_t decompressed_size = ZSTD_getFrameContentSize(data->buf, data->len); + + /* These two zstd constants always > PY_SSIZE_T_MAX: + ZSTD_CONTENTSIZE_UNKNOWN is (0ULL - 1) + ZSTD_CONTENTSIZE_ERROR is (0ULL - 2) + + Use ZSTD_findFrameCompressedSize() to check complete frame, + prevent allocating too much memory for small input chunk. 
*/ + + if (decompressed_size <= (uint64_t) PY_SSIZE_T_MAX && + !ZSTD_isError(ZSTD_findFrameCompressedSize(data->buf, data->len)) ) + { + initial_buffer_size = (Py_ssize_t) decompressed_size; + } + } + } + + /* Prepare input buffer w/wo unconsumed data */ + if (self->in_begin == self->in_end) { + /* No unconsumed data */ + use_input_buffer = 0; + + in.src = data->buf; + in.size = data->len; + in.pos = 0; + } + else if (data->len == 0) { + /* Has unconsumed data, fast path for b'' */ + assert(self->in_begin < self->in_end); + + use_input_buffer = 1; + + in.src = self->input_buffer + self->in_begin; + in.size = self->in_end - self->in_begin; + in.pos = 0; + } + else { + /* Has unconsumed data */ + use_input_buffer = 1; + + /* Unconsumed data size in input_buffer */ + size_t used_now = self->in_end - self->in_begin; + assert(self->in_end > self->in_begin); + + /* Number of bytes we can append to input buffer */ + size_t avail_now = self->input_buffer_size - self->in_end; + assert(self->input_buffer_size >= self->in_end); + + /* Number of bytes we can append if we move existing contents to + beginning of buffer */ + size_t avail_total = self->input_buffer_size - used_now; + assert(self->input_buffer_size >= used_now); + + if (avail_total < (size_t) data->len) { + char *tmp; + size_t new_size = used_now + data->len; + + /* Allocate with new size */ + tmp = PyMem_Malloc(new_size); + if (tmp == NULL) { + PyErr_NoMemory(); + goto error; + } + + /* Copy unconsumed data to the beginning of new buffer */ + memcpy(tmp, + self->input_buffer + self->in_begin, + used_now); + + /* Switch to new buffer */ + PyMem_Free(self->input_buffer); + self->input_buffer = tmp; + self->input_buffer_size = new_size; + + /* Set begin & end position */ + self->in_begin = 0; + self->in_end = used_now; + } + else if (avail_now < (size_t) data->len) { + /* Move unconsumed data to the beginning. + Overlap is possible, so use memmove(). 
*/ + memmove(self->input_buffer, + self->input_buffer + self->in_begin, + used_now); + + /* Set begin & end position */ + self->in_begin = 0; + self->in_end = used_now; + } + + /* Copy data to input buffer */ + memcpy(self->input_buffer + self->in_end, data->buf, data->len); + self->in_end += data->len; + + in.src = self->input_buffer + self->in_begin; + in.size = used_now + data->len; + in.pos = 0; + } + assert(in.pos == 0); + + /* Decompress */ + ret = decompress_impl(self, &in, + max_length, initial_buffer_size, + type); + if (ret == NULL) { + goto error; + } + + /* Unconsumed input data */ + if (in.pos == in.size) { + if (type == TYPE_DECOMPRESSOR) { + if (Py_SIZE(ret) == max_length || self->eof) { + self->needs_input = 0; + } + else { + self->needs_input = 1; + } + } + else if (type == TYPE_ENDLESS_DECOMPRESSOR) { + if (Py_SIZE(ret) == max_length && !self->at_frame_edge) { + self->needs_input = 0; + } + else { + self->needs_input = 1; + } + } + + if (use_input_buffer) { + /* Clear input_buffer */ + self->in_begin = 0; + self->in_end = 0; + } + } + else { + size_t data_size = in.size - in.pos; + + self->needs_input = 0; + + if (type == TYPE_ENDLESS_DECOMPRESSOR) { + self->at_frame_edge = 0; + } + + if (!use_input_buffer) { + /* Discard buffer if it's too small + (resizing it may needlessly copy the current contents) */ + if (self->input_buffer != NULL && + self->input_buffer_size < data_size) + { + PyMem_Free(self->input_buffer); + self->input_buffer = NULL; + self->input_buffer_size = 0; + } + + /* Allocate if necessary */ + if (self->input_buffer == NULL) { + self->input_buffer = PyMem_Malloc(data_size); + if (self->input_buffer == NULL) { + PyErr_NoMemory(); + goto error; + } + self->input_buffer_size = data_size; + } + + /* Copy unconsumed data */ + memcpy(self->input_buffer, (char*)in.src + in.pos, data_size); + self->in_begin = 0; + self->in_end = data_size; + } + else { + /* Use input buffer */ + self->in_begin += in.pos; + } + } + + goto success; + +error: + /* Reset decompressor's states/session */ + decompressor_reset_session(self, type); + + Py_CLEAR(ret); +success: + + return ret; +} + + +static PyObject * +_zstd_ZstdDecompressor_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + ZstdDecompressor *self; + self = PyObject_GC_New(ZstdDecompressor, type); + if (self == NULL) { + goto error; + } + + self->inited = 0; + self->dict = NULL; + self->input_buffer = NULL; + self->input_buffer_size = 0; + self->in_begin = -1; + self->in_end = -1; + self->unused_data = NULL; + self->eof = 0; + + /* needs_input flag */ + self->needs_input = 1; + + /* at_frame_edge flag */ + self->at_frame_edge = 1; + + /* Decompression context */ + self->dctx = ZSTD_createDCtx(); + if (self->dctx == NULL) { + _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (mod_state != NULL) { + PyErr_SetString(mod_state->ZstdError, + "Unable to create ZSTD_DCtx instance."); + } + goto error; + } + + return (PyObject*)self; + +error: + if (self != NULL) { + PyObject_GC_Del(self); + } + return NULL; +} + +static void +ZstdDecompressor_dealloc(PyObject *ob) +{ + ZstdDecompressor *self = ZstdDecompressor_CAST(ob); + + PyObject_GC_UnTrack(self); + + /* Free decompression context */ + ZSTD_freeDCtx(self->dctx); + + /* Py_CLEAR the dict after free decompression context */ + Py_CLEAR(self->dict); + + /* Free unconsumed input data buffer */ + PyMem_Free(self->input_buffer); + + /* Free unused data */ + Py_CLEAR(self->unused_data); + + PyTypeObject *tp = Py_TYPE(self); + PyObject_GC_Del(ob); 
+    Py_DECREF(tp);
+}
+
+/*[clinic input]
+_zstd.ZstdDecompressor.__init__
+
+    zstd_dict: object = None
+        A ZstdDict object, a pre-trained zstd dictionary.
+    options: object = None
+        A dict object that contains advanced decompression parameters.
+
+Create a decompressor object for decompressing data incrementally.
+
+Thread-safe at method level. For one-shot decompression, use the decompress()
+function instead.
+[clinic start generated code]*/
+
+static int
+_zstd_ZstdDecompressor___init___impl(ZstdDecompressor *self,
+                                     PyObject *zstd_dict, PyObject *options)
+/*[clinic end generated code: output=703af2f1ec226642 input=8fd72999acc1a146]*/
+{
+    /* Only called once */
+    if (self->inited) {
+        PyErr_SetString(PyExc_RuntimeError, init_twice_msg);
+        return -1;
+    }
+    self->inited = 1;
+
+    /* Load the dictionary into the decompression context */
+    if (zstd_dict != Py_None) {
+        if (_PyZstd_load_d_dict(self, zstd_dict) < 0) {
+            return -1;
+        }
+
+        /* Py_INCREF the dict */
+        Py_INCREF(zstd_dict);
+        self->dict = zstd_dict;
+    }
+
+    /* Set options on the decompression context */
+    if (options != Py_None) {
+        if (_PyZstd_set_d_parameters(self, options) < 0) {
+            return -1;
+        }
+    }
+
+    // We can only start tracking self with the GC once self->dict is set.
+    PyObject_GC_Track(self);
+    return 0;
+}
+
+/*[clinic input]
+@critical_section
+@getter
+_zstd.ZstdDecompressor.unused_data
+
+A bytes object of un-consumed input data.
+
+When the ZstdDecompressor object stops after a frame is decompressed, this is
+the unused input data after that frame. Otherwise this will be b''.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
+/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/
+{
+    PyObject *ret;
+
+    /* Thread-safe code */
+    Py_BEGIN_CRITICAL_SECTION(self);
+
+    if (!self->eof) {
+        _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+        if (mod_state == NULL) {
+            return NULL;
+        }
+        ret = mod_state->empty_bytes;
+        Py_INCREF(ret);
+    }
+    else {
+        if (self->unused_data == NULL) {
+            self->unused_data = PyBytes_FromStringAndSize(
+                self->input_buffer + self->in_begin,
+                self->in_end - self->in_begin);
+            ret = self->unused_data;
+            Py_XINCREF(ret);
+        }
+        else {
+            ret = self->unused_data;
+            Py_INCREF(ret);
+        }
+    }
+
+    Py_END_CRITICAL_SECTION();
+
+    return ret;
+}
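For orientation, here is a minimal Python-level sketch of the constructor, eof and unused_data behaviour implemented above. It assumes that the Lib/compression/zstd wrapper re-exports ZstdDecompressor and a one-shot compress() helper (both are referenced elsewhere in this change but not defined in this file)::

    from compression.zstd import ZstdDecompressor, compress  # import path assumed

    frame = compress(b"hello zstd")     # one complete zstd frame
    decomp = ZstdDecompressor()         # optionally ZstdDecompressor(zstd_dict=zd)

    out = decomp.decompress(frame + b"trailing bytes")
    assert out == b"hello zstd"
    assert decomp.eof                   # stopped at the end of the first frame
    assert decomp.unused_data == b"trailing bytes"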
+/*[clinic input]
+_zstd.ZstdDecompressor.decompress
+
+    data: Py_buffer
+        A bytes-like object, zstd data to be decompressed.
+    max_length: Py_ssize_t = -1
+        Maximum size of returned data. When it is negative, the size of the
+        output buffer is unlimited. When it is nonnegative, returns at
+        most max_length bytes of decompressed data.
+
+Decompress *data*, returning uncompressed bytes if possible, or b'' otherwise.
+
+If *max_length* is nonnegative, returns at most *max_length* bytes of
+decompressed data. If this limit is reached and further output can be
+produced, *self.needs_input* will be set to ``False``. In this case, the next
+call to *decompress()* may provide *data* as b'' to obtain more of the output.
+
+If all of the input data was decompressed and returned (either because this
+was less than *max_length* bytes, or because *max_length* was negative),
+*self.needs_input* will be set to ``True``.
+
+Attempting to decompress data after the end of a frame is reached raises an
+EOFError. Any data found after the end of the frame is ignored and saved in
+the self.unused_data attribute.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
+                                       Py_buffer *data,
+                                       Py_ssize_t max_length)
+/*[clinic end generated code: output=a4302b3c940dbec6 input=830e455bc9a50b6e]*/
+{
+    PyObject *ret;
+    /* Thread-safe code */
+    Py_BEGIN_CRITICAL_SECTION(self);
+
+    ret = stream_decompress(self, data, max_length, TYPE_DECOMPRESSOR);
+    Py_END_CRITICAL_SECTION();
+    return ret;
+}
+
+#define clinic_state() (get_zstd_state_from_type(type))
+#include "clinic/decompressor.c.h"
+#undef clinic_state
+
+static PyMethodDef ZstdDecompressor_methods[] = {
+    _ZSTD_ZSTDDECOMPRESSOR_DECOMPRESS_METHODDEF
+
+    {0}
+};
+
+PyDoc_STRVAR(ZstdDecompressor_eof_doc,
+"True means the end of the first frame has been reached. If you decompress\n"
+"data after that, an EOFError exception will be raised.");
+
+PyDoc_STRVAR(ZstdDecompressor_needs_input_doc,
+"If the max_length output limit of the .decompress() method has been reached,\n"
+"and the decompressor has (or may have) unconsumed input data, this will be\n"
+"set to False. In this case, passing b'' to the .decompress() method may\n"
+"output further data.");
+
+static PyMemberDef ZstdDecompressor_members[] = {
+    {"eof", Py_T_BOOL, offsetof(ZstdDecompressor, eof),
+     Py_READONLY, ZstdDecompressor_eof_doc},
+
+    {"needs_input", Py_T_BOOL, offsetof(ZstdDecompressor, needs_input),
+     Py_READONLY, ZstdDecompressor_needs_input_doc},
+
+    {0}
+};
+
+static PyGetSetDef ZstdDecompressor_getset[] = {
+    _ZSTD_ZSTDDECOMPRESSOR_UNUSED_DATA_GETSETDEF
+
+    {0}
+};
+
+static int
+ZstdDecompressor_traverse(PyObject *ob, visitproc visit, void *arg)
+{
+    ZstdDecompressor *self = ZstdDecompressor_CAST(ob);
+    Py_VISIT(self->dict);
+    return 0;
+}
+
+static int
+ZstdDecompressor_clear(PyObject *ob)
+{
+    ZstdDecompressor *self = ZstdDecompressor_CAST(ob);
+    Py_CLEAR(self->dict);
+    Py_CLEAR(self->unused_data);
+    return 0;
+}
+
+static PyType_Slot ZstdDecompressor_slots[] = {
+    {Py_tp_new, _zstd_ZstdDecompressor_new},
+    {Py_tp_dealloc, ZstdDecompressor_dealloc},
+    {Py_tp_init, _zstd_ZstdDecompressor___init__},
+    {Py_tp_methods, ZstdDecompressor_methods},
+    {Py_tp_members, ZstdDecompressor_members},
+    {Py_tp_getset, ZstdDecompressor_getset},
+    {Py_tp_doc, (char*)_zstd_ZstdDecompressor___init____doc__},
+    {Py_tp_traverse, ZstdDecompressor_traverse},
+    {Py_tp_clear, ZstdDecompressor_clear},
+    {0}
+};
+
+PyType_Spec ZstdDecompressor_type_spec = {
+    .name = "_zstd.ZstdDecompressor",
+    .basicsize = sizeof(ZstdDecompressor),
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
+    .slots = ZstdDecompressor_slots,
+};
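The max_length / needs_input protocol documented above follows the bz2 and lzma decompressors. A short streaming sketch, under the same assumption that compression.zstd exposes ZstdDecompressor and a one-shot compress() helper::

    import io
    from compression.zstd import ZstdDecompressor, compress  # import path assumed

    frame = compress(b"0123456789" * 20_000)
    src = io.BytesIO(frame)
    decomp = ZstdDecompressor()
    result = bytearray()

    while not decomp.eof:
        if decomp.needs_input:
            chunk = src.read(8192)   # feed more compressed input
        else:
            chunk = b""              # output still pending for input already given
        result += decomp.decompress(chunk, max_length=16384)

    assert bytes(result) == b"0123456789" * 20_000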
diff --git a/Modules/_zstd/zdict.c b/Modules/_zstd/zdict.c
new file mode 100644
index 00000000000000..28ab964a6caa87
--- /dev/null
+++ b/Modules/_zstd/zdict.c
@@ -0,0 +1,286 @@
+/*
+Low level interface to Meta's zstd library for use in the compression.zstd
+Python module.
+*/
+
+/* ZstdDict class definitions */
+
+/*[clinic input]
+module _zstd
+class _zstd.ZstdDict "ZstdDict *" "clinic_state()->ZstdDict_type"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=a5d1254c497e52ba]*/
+
+#ifndef Py_BUILD_CORE_BUILTIN
+# define Py_BUILD_CORE_MODULE 1
+#endif
+
+#include "_zstdmodule.h"
+
+#include <stddef.h>               // offsetof()
+
+#define ZstdDict_CAST(op) ((ZstdDict *)op)
+
+static PyObject *
+_zstd_ZstdDict_new(PyTypeObject *type, PyObject *Py_UNUSED(args), PyObject *Py_UNUSED(kwargs))
+{
+    ZstdDict *self;
+    self = PyObject_GC_New(ZstdDict, type);
+    if (self == NULL) {
+        goto error;
+    }
+
+    self->dict_content = NULL;
+    self->inited = 0;
+    self->d_dict = NULL;
+
+    /* ZSTD_CDict dict */
+    self->c_dicts = PyDict_New();
+    if (self->c_dicts == NULL) {
+        goto error;
+    }
+
+    return (PyObject*)self;
+
+error:
+    if (self != NULL) {
+        PyObject_GC_Del(self);
+    }
+    return NULL;
+}
+
+static void
+ZstdDict_dealloc(PyObject *ob)
+{
+    ZstdDict *self = ZstdDict_CAST(ob);
+
+    PyObject_GC_UnTrack(self);
+
+    /* Free the ZSTD_DDict instance */
+    ZSTD_freeDDict(self->d_dict);
+
+    /* Release dict_content after freeing the ZSTD_CDict/ZSTD_DDict instances */
+    Py_CLEAR(self->dict_content);
+    Py_CLEAR(self->c_dicts);
+
+    PyTypeObject *tp = Py_TYPE(self);
+    PyObject_GC_Del(ob);
+    Py_DECREF(tp);
+}
+
+/*[clinic input]
+_zstd.ZstdDict.__init__
+
+    dict_content: object
+        A bytes-like object, the dictionary's content.
+    is_raw: bool = False
+        This parameter is for advanced users. True means the dict_content
+        argument is a "raw content" dictionary, free of any format
+        restriction. False means the dict_content argument is an ordinary
+        zstd dictionary, created by zstd functions and following a
+        specified format.
+
+Represents a zstd dictionary, which can be used for compression/decompression.
+
+It's thread-safe, and can be shared by multiple ZstdCompressor /
+ZstdDecompressor objects.
+[clinic start generated code]*/
+
+static int
+_zstd_ZstdDict___init___impl(ZstdDict *self, PyObject *dict_content,
+                             int is_raw)
+/*[clinic end generated code: output=c5f5a0d8377d037c input=e6750f62a513b3ee]*/
+{
+    /* Only called once */
+    if (self->inited) {
+        PyErr_SetString(PyExc_RuntimeError, init_twice_msg);
+        return -1;
+    }
+    self->inited = 1;
+
+    /* Check dict_content's type */
+    self->dict_content = PyBytes_FromObject(dict_content);
+    if (self->dict_content == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "dict_content argument should be a bytes-like object.");
+        return -1;
+    }
+
+    /* Both an ordinary dictionary and a "raw content" dictionary should be
+       at least 8 bytes */
+    if (Py_SIZE(self->dict_content) < 8) {
+        PyErr_SetString(PyExc_ValueError,
+                        "Zstd dictionary content should be at least 8 bytes.");
+        return -1;
+    }
+
+    /* Get dict_id, 0 means "raw content" dictionary. */
+    self->dict_id = ZSTD_getDictID_fromDict(PyBytes_AS_STRING(self->dict_content),
+                                            Py_SIZE(self->dict_content));
+
+    /* Check validity for an ordinary dictionary */
+    if (!is_raw && self->dict_id == 0) {
+        char *msg = "The dict_content argument is not a valid zstd "
+                    "dictionary. The first 4 bytes of a valid zstd dictionary "
+                    "should be a magic number: b'\\x37\\xA4\\x30\\xEC'.\n"
+                    "If you are an advanced user, and can be sure that the "
+                    "dict_content argument is a \"raw content\" zstd "
+                    "dictionary, set the is_raw parameter to True.";
+        PyErr_SetString(PyExc_ValueError, msg);
+        return -1;
+    }
+
+    // Can only start tracking self with the GC once self->dict_content is set.
+    PyObject_GC_Track(self);
+    return 0;
+}
+
+#define clinic_state() (get_zstd_state(type))
+#include "clinic/zdict.c.h"
+#undef clinic_state
+
+PyDoc_STRVAR(ZstdDict_dictid_doc,
+"ID of the zstd dictionary, a 32-bit unsigned int value.\n\n"
+"Non-zero means an ordinary dictionary, created by zstd functions and\n"
+"following a specified format.\n\n"
+"0 means a \"raw content\" dictionary, free of any format restriction, used\n"
+"by advanced users.");
+
+PyDoc_STRVAR(ZstdDict_dictcontent_doc,
+"The content of the zstd dictionary, a bytes object. It's the same as the\n"
+"dict_content argument in the ZstdDict.__init__() method. It can be used\n"
+"with other programs.");
+
+static PyObject *
+ZstdDict_str(PyObject *ob)
+{
+    ZstdDict *dict = ZstdDict_CAST(ob);
+    return PyUnicode_FromFormat("<ZstdDict dict_id=%u dict_size=%zd>",
+                                dict->dict_id, Py_SIZE(dict->dict_content));
+}
+
+static PyMemberDef ZstdDict_members[] = {
+    {"dict_id", Py_T_UINT, offsetof(ZstdDict, dict_id), Py_READONLY, ZstdDict_dictid_doc},
+    {"dict_content", Py_T_OBJECT_EX, offsetof(ZstdDict, dict_content), Py_READONLY, ZstdDict_dictcontent_doc},
+    {0}
+};
+
+/*[clinic input]
+@critical_section
+@getter
+_zstd.ZstdDict.as_digested_dict
+
+Load as a digested dictionary into the compressor.
+
+Pass this attribute as the zstd_dict argument: compress(dat, zstd_dict=zd.as_digested_dict)
+1. Some advanced compression parameters of the compressor may be overridden
+   by parameters of the digested dictionary.
+2. ZstdDict has a cache of digested dictionaries, one per compression level.
+   It's faster to load a digested dictionary again at the same compression
+   level.
+3. No need to use this for decompression.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self)
+/*[clinic end generated code: output=09b086e7a7320dbb input=585448c79f31f74a]*/
+{
+    return Py_BuildValue("Oi", self, DICT_TYPE_DIGESTED);
+}
+
+/*[clinic input]
+@critical_section
+@getter
+_zstd.ZstdDict.as_undigested_dict
+
+Load as an undigested dictionary into the compressor.
+
+Pass this attribute as the zstd_dict argument: compress(dat, zstd_dict=zd.as_undigested_dict)
+1. The advanced compression parameters of the compressor will not be overridden.
+2. Loading an undigested dictionary is costly. If you need to load an
+   undigested dictionary multiple times, consider reusing a compressor object.
+3. No need to use this for decompression.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self)
+/*[clinic end generated code: output=43c7a989e6d4253a input=022b0829ffb1c220]*/
+{
+    return Py_BuildValue("Oi", self, DICT_TYPE_UNDIGESTED);
+}
+
+/*[clinic input]
+@critical_section
+@getter
+_zstd.ZstdDict.as_prefix
+
+Load as a prefix to the compressor/decompressor.
+
+Pass this attribute as the zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)
+1. A prefix is compatible with long-distance matching, while a dictionary is not.
+2. It only works for the first frame; after that, the compressor/decompressor
+   will return to the no-prefix state.
+3. When decompressing, the same prefix must be used as when compressing.
+[clinic start generated code]*/ + +static PyObject * +_zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self) +/*[clinic end generated code: output=6f7130c356595a16 input=09fb82a6a5407e87]*/ +{ + return Py_BuildValue("Oi", self, DICT_TYPE_PREFIX); +} + +static PyGetSetDef ZstdDict_getset[] = { + _ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF + + _ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF + + _ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF + + {0} +}; + +static Py_ssize_t +ZstdDict_length(PyObject *ob) +{ + ZstdDict *self = ZstdDict_CAST(ob); + assert(PyBytes_Check(self->dict_content)); + return Py_SIZE(self->dict_content); +} + +static int +ZstdDict_traverse(PyObject *ob, visitproc visit, void *arg) +{ + ZstdDict *self = ZstdDict_CAST(ob); + Py_VISIT(self->c_dicts); + Py_VISIT(self->dict_content); + return 0; +} + +static int +ZstdDict_clear(PyObject *ob) +{ + ZstdDict *self = ZstdDict_CAST(ob); + Py_CLEAR(self->dict_content); + return 0; +} + +static PyType_Slot zstddict_slots[] = { + {Py_tp_members, ZstdDict_members}, + {Py_tp_getset, ZstdDict_getset}, + {Py_tp_new, _zstd_ZstdDict_new}, + {Py_tp_dealloc, ZstdDict_dealloc}, + {Py_tp_init, _zstd_ZstdDict___init__}, + {Py_tp_str, ZstdDict_str}, + {Py_tp_doc, (char*)_zstd_ZstdDict___init____doc__}, + {Py_sq_length, ZstdDict_length}, + {Py_tp_traverse, ZstdDict_traverse}, + {Py_tp_clear, ZstdDict_clear}, + {0} +}; + +PyType_Spec zstddict_type_spec = { + .name = "_zstd.ZstdDict", + .basicsize = sizeof(ZstdDict), + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .slots = zstddict_slots, +}; diff --git a/Python/stdlib_module_names.h b/Python/stdlib_module_names.h index fcef7419bd397b..92cc5afedc8429 100644 --- a/Python/stdlib_module_names.h +++ b/Python/stdlib_module_names.h @@ -103,6 +103,7 @@ static const char* _Py_stdlib_module_names[] = { "_winapi", "_wmi", "_zoneinfo", +"_zstd", "abc", "annotationlib", "antigravity", diff --git a/Tools/c-analyzer/cpython/ignored.tsv b/Tools/c-analyzer/cpython/ignored.tsv index a33619b1b345e2..2be3e1a420b91d 100644 --- a/Tools/c-analyzer/cpython/ignored.tsv +++ b/Tools/c-analyzer/cpython/ignored.tsv @@ -748,6 +748,7 @@ Modules/expat/xmlrole.c - error - ## other Modules/_io/_iomodule.c - _PyIO_Module - Modules/_sqlite/module.c - _sqlite3module - +Modules/_zstd/_zstdmodule.c - _zstdmodule - Modules/clinic/md5module.c.h _md5_md5 _keywords - Modules/clinic/grpmodule.c.h grp_getgrgid _keywords - Modules/clinic/grpmodule.c.h grp_getgrnam _keywords - diff --git a/configure b/configure index 205f196a25af2d..9a33e8d1875a22 100755 --- a/configure +++ b/configure @@ -678,6 +678,8 @@ MODULE__HASHLIB_FALSE MODULE__HASHLIB_TRUE MODULE__SSL_FALSE MODULE__SSL_TRUE +MODULE__ZSTD_FALSE +MODULE__ZSTD_TRUE MODULE__LZMA_FALSE MODULE__LZMA_TRUE MODULE__BZ2_FALSE @@ -852,6 +854,8 @@ HAVE_GETHOSTBYNAME_R_3_ARG HAVE_GETHOSTBYNAME_R_5_ARG HAVE_GETHOSTBYNAME_R_6_ARG LIBOBJS +LIBZSTD_LIBS +LIBZSTD_CFLAGS LIBLZMA_LIBS LIBLZMA_CFLAGS BZIP2_LIBS @@ -1172,6 +1176,8 @@ BZIP2_CFLAGS BZIP2_LIBS LIBLZMA_CFLAGS LIBLZMA_LIBS +LIBZSTD_CFLAGS +LIBZSTD_LIBS LIBREADLINE_CFLAGS LIBREADLINE_LIBS LIBEDIT_CFLAGS @@ -2011,6 +2017,10 @@ Some influential environment variables: C compiler flags for LIBLZMA, overriding pkg-config LIBLZMA_LIBS linker flags for LIBLZMA, overriding pkg-config + LIBZSTD_CFLAGS + C compiler flags for LIBZSTD, overriding pkg-config + LIBZSTD_LIBS + linker flags for LIBZSTD, overriding pkg-config LIBREADLINE_CFLAGS C compiler flags for LIBREADLINE, overriding pkg-config LIBREADLINE_LIBS @@ -22405,6 +22415,260 @@ printf "%s\n" 
"yes" >&6; } fi +pkg_failed=no +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for libzstd" >&5 +printf %s "checking for libzstd... " >&6; } + +if test -n "$LIBZSTD_CFLAGS"; then + pkg_cv_LIBZSTD_CFLAGS="$LIBZSTD_CFLAGS" + elif test -n "$PKG_CONFIG"; then + if test -n "$PKG_CONFIG" && \ + { { printf "%s\n" "$as_me:${as_lineno-$LINENO}: \$PKG_CONFIG --exists --print-errors \"libzstd\""; } >&5 + ($PKG_CONFIG --exists --print-errors "libzstd") 2>&5 + ac_status=$? + printf "%s\n" "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; then + pkg_cv_LIBZSTD_CFLAGS=`$PKG_CONFIG --cflags "libzstd" 2>/dev/null` + test "x$?" != "x0" && pkg_failed=yes +else + pkg_failed=yes +fi + else + pkg_failed=untried +fi +if test -n "$LIBZSTD_LIBS"; then + pkg_cv_LIBZSTD_LIBS="$LIBZSTD_LIBS" + elif test -n "$PKG_CONFIG"; then + if test -n "$PKG_CONFIG" && \ + { { printf "%s\n" "$as_me:${as_lineno-$LINENO}: \$PKG_CONFIG --exists --print-errors \"libzstd\""; } >&5 + ($PKG_CONFIG --exists --print-errors "libzstd") 2>&5 + ac_status=$? + printf "%s\n" "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; then + pkg_cv_LIBZSTD_LIBS=`$PKG_CONFIG --libs "libzstd" 2>/dev/null` + test "x$?" != "x0" && pkg_failed=yes +else + pkg_failed=yes +fi + else + pkg_failed=untried +fi + + + +if test $pkg_failed = yes; then + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: no" >&5 +printf "%s\n" "no" >&6; } + +if $PKG_CONFIG --atleast-pkgconfig-version 0.20; then + _pkg_short_errors_supported=yes +else + _pkg_short_errors_supported=no +fi + if test $_pkg_short_errors_supported = yes; then + LIBZSTD_PKG_ERRORS=`$PKG_CONFIG --short-errors --print-errors --cflags --libs "libzstd" 2>&1` + else + LIBZSTD_PKG_ERRORS=`$PKG_CONFIG --print-errors --cflags --libs "libzstd" 2>&1` + fi + # Put the nasty error message in config.log where it belongs + echo "$LIBZSTD_PKG_ERRORS" >&5 + + + save_CFLAGS=$CFLAGS +save_CPPFLAGS=$CPPFLAGS +save_LDFLAGS=$LDFLAGS +save_LIBS=$LIBS + + + CPPFLAGS="$CPPFLAGS $LIBZSTD_CFLAGS" + LIBS="$LIBS $LIBZSTD_LIBS" + for ac_header in zstd.h zdict.h +do : + as_ac_Header=`printf "%s\n" "ac_cv_header_$ac_header" | sed "$as_sed_sh"` +ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" +if eval test \"x\$"$as_ac_Header"\" = x"yes" +then : + cat >>confdefs.h <<_ACEOF +#define `printf "%s\n" "HAVE_$ac_header" | sed "$as_sed_cpp"` 1 +_ACEOF + + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for ZSTD_compress in -lzstd" >&5 +printf %s "checking for ZSTD_compress in -lzstd... " >&6; } +if test ${ac_cv_lib_zstd_ZSTD_compress+y} +then : + printf %s "(cached) " >&6 +else case e in #( + e) ac_check_lib_save_LIBS=$LIBS +LIBS="-lzstd $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. + The 'extern "C"' is for builds by C++ compilers; + although this is not generally supported in C code supporting it here + has little cost and some practical benefit (sr 110532). 
*/ +#ifdef __cplusplus +extern "C" +#endif +char ZSTD_compress (void); +int +main (void) +{ +return ZSTD_compress (); + ; + return 0; +} +_ACEOF +if ac_fn_c_try_link "$LINENO" +then : + ac_cv_lib_zstd_ZSTD_compress=yes +else case e in #( + e) ac_cv_lib_zstd_ZSTD_compress=no ;; +esac +fi +rm -f core conftest.err conftest.$ac_objext conftest.beam \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS ;; +esac +fi +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_zstd_ZSTD_compress" >&5 +printf "%s\n" "$ac_cv_lib_zstd_ZSTD_compress" >&6; } +if test "x$ac_cv_lib_zstd_ZSTD_compress" = xyes +then : + have_libzstd=yes +else case e in #( + e) have_libzstd=no ;; +esac +fi + + +else case e in #( + e) have_libzstd=no ;; +esac +fi + +done + if test "x$have_libzstd" = xyes +then : + + LIBZSTD_CFLAGS=${LIBZSTD_CFLAGS-""} + LIBZSTD_LIBS=${LIBZSTD_LIBS-"-lzstd"} + +fi + +CFLAGS=$save_CFLAGS +CPPFLAGS=$save_CPPFLAGS +LDFLAGS=$save_LDFLAGS +LIBS=$save_LIBS + + + +elif test $pkg_failed = untried; then + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: no" >&5 +printf "%s\n" "no" >&6; } + + save_CFLAGS=$CFLAGS +save_CPPFLAGS=$CPPFLAGS +save_LDFLAGS=$LDFLAGS +save_LIBS=$LIBS + + + CPPFLAGS="$CPPFLAGS $LIBZSTD_CFLAGS" + LIBS="$LIBS $LIBZSTD_LIBS" + for ac_header in zstd.h zdict.h +do : + as_ac_Header=`printf "%s\n" "ac_cv_header_$ac_header" | sed "$as_sed_sh"` +ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" +if eval test \"x\$"$as_ac_Header"\" = x"yes" +then : + cat >>confdefs.h <<_ACEOF +#define `printf "%s\n" "HAVE_$ac_header" | sed "$as_sed_cpp"` 1 +_ACEOF + + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for ZSTD_compress in -lzstd" >&5 +printf %s "checking for ZSTD_compress in -lzstd... " >&6; } +if test ${ac_cv_lib_zstd_ZSTD_compress+y} +then : + printf %s "(cached) " >&6 +else case e in #( + e) ac_check_lib_save_LIBS=$LIBS +LIBS="-lzstd $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. + The 'extern "C"' is for builds by C++ compilers; + although this is not generally supported in C code supporting it here + has little cost and some practical benefit (sr 110532). 
*/ +#ifdef __cplusplus +extern "C" +#endif +char ZSTD_compress (void); +int +main (void) +{ +return ZSTD_compress (); + ; + return 0; +} +_ACEOF +if ac_fn_c_try_link "$LINENO" +then : + ac_cv_lib_zstd_ZSTD_compress=yes +else case e in #( + e) ac_cv_lib_zstd_ZSTD_compress=no ;; +esac +fi +rm -f core conftest.err conftest.$ac_objext conftest.beam \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS ;; +esac +fi +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_zstd_ZSTD_compress" >&5 +printf "%s\n" "$ac_cv_lib_zstd_ZSTD_compress" >&6; } +if test "x$ac_cv_lib_zstd_ZSTD_compress" = xyes +then : + have_libzstd=yes +else case e in #( + e) have_libzstd=no ;; +esac +fi + + +else case e in #( + e) have_libzstd=no ;; +esac +fi + +done + if test "x$have_libzstd" = xyes +then : + + LIBZSTD_CFLAGS=${LIBZSTD_CFLAGS-""} + LIBZSTD_LIBS=${LIBZSTD_LIBS-"-lzstd"} + +fi + +CFLAGS=$save_CFLAGS +CPPFLAGS=$save_CPPFLAGS +LDFLAGS=$save_LDFLAGS +LIBS=$save_LIBS + + + +else + LIBZSTD_CFLAGS=$pkg_cv_LIBZSTD_CFLAGS + LIBZSTD_LIBS=$pkg_cv_LIBZSTD_LIBS + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: yes" >&5 +printf "%s\n" "yes" >&6; } + have_libzstd=yes +fi + + @@ -29406,6 +29670,7 @@ SRCDIRS="\ Modules/_xxtestfuzz \ Modules/cjkcodecs \ Modules/expat \ + Modules/_zstd \ Objects \ Objects/mimalloc \ Objects/mimalloc/prim \ @@ -32982,6 +33247,46 @@ fi printf "%s\n" "$py_cv_module__lzma" >&6; } + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for stdlib extension module _zstd" >&5 +printf %s "checking for stdlib extension module _zstd... " >&6; } + if test "$py_cv_module__zstd" != "n/a" +then : + + if true +then : + if test "$have_libzstd" = yes +then : + py_cv_module__zstd=yes +else case e in #( + e) py_cv_module__zstd=missing ;; +esac +fi +else case e in #( + e) py_cv_module__zstd=disabled ;; +esac +fi + +fi + as_fn_append MODULE_BLOCK "MODULE__ZSTD_STATE=$py_cv_module__zstd$as_nl" + if test "x$py_cv_module__zstd" = xyes +then : + + as_fn_append MODULE_BLOCK "MODULE__ZSTD_CFLAGS=$LIBZSTD_CFLAGS$as_nl" + as_fn_append MODULE_BLOCK "MODULE__ZSTD_LDFLAGS=$LIBZSTD_LIBS$as_nl" + +fi + if test "$py_cv_module__zstd" = yes; then + MODULE__ZSTD_TRUE= + MODULE__ZSTD_FALSE='#' +else + MODULE__ZSTD_TRUE='#' + MODULE__ZSTD_FALSE= +fi + + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $py_cv_module__zstd" >&5 +printf "%s\n" "$py_cv_module__zstd" >&6; } + + { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for stdlib extension module _ssl" >&5 printf %s "checking for stdlib extension module _ssl... " >&6; } @@ -34050,6 +34355,10 @@ if test -z "${MODULE__LZMA_TRUE}" && test -z "${MODULE__LZMA_FALSE}"; then as_fn_error $? "conditional \"MODULE__LZMA\" was never defined. Usually this means the macro was only invoked conditionally." "$LINENO" 5 fi +if test -z "${MODULE__ZSTD_TRUE}" && test -z "${MODULE__ZSTD_FALSE}"; then + as_fn_error $? "conditional \"MODULE__ZSTD\" was never defined. +Usually this means the macro was only invoked conditionally." "$LINENO" 5 +fi if test -z "${MODULE__SSL_TRUE}" && test -z "${MODULE__SSL_FALSE}"; then as_fn_error $? "conditional \"MODULE__SSL\" was never defined. Usually this means the macro was only invoked conditionally." 
"$LINENO" 5 diff --git a/configure.ac b/configure.ac index f0ae7fbec1cbfe..d4eba4ac1bda83 100644 --- a/configure.ac +++ b/configure.ac @@ -5394,6 +5394,20 @@ PKG_CHECK_MODULES([LIBLZMA], [liblzma], [have_liblzma=yes], [ ]) ]) +PKG_CHECK_MODULES([LIBZSTD], [libzstd], [have_libzstd=yes], [ + WITH_SAVE_ENV([ + CPPFLAGS="$CPPFLAGS $LIBZSTD_CFLAGS" + LIBS="$LIBS $LIBZSTD_LIBS" + AC_CHECK_HEADERS([zstd.h zdict.h], [ + AC_CHECK_LIB([zstd], [ZSTD_compress], [have_libzstd=yes], [have_libzstd=no]) + ], [have_libzstd=no]) + AS_VAR_IF([have_libzstd], [yes], [ + LIBZSTD_CFLAGS=${LIBZSTD_CFLAGS-""} + LIBZSTD_LIBS=${LIBZSTD_LIBS-"-lzstd"} + ]) + ]) +]) + dnl PY_CHECK_NETDB_FUNC(FUNCTION) AC_DEFUN([PY_CHECK_NETDB_FUNC], [PY_CHECK_FUNC([$1], [@%:@include ])]) @@ -7095,6 +7109,7 @@ SRCDIRS="\ Modules/_xxtestfuzz \ Modules/cjkcodecs \ Modules/expat \ + Modules/_zstd \ Objects \ Objects/mimalloc \ Objects/mimalloc/prim \ @@ -8041,6 +8056,8 @@ PY_STDLIB_MOD([_bz2], [], [test "$have_bzip2" = yes], [$BZIP2_CFLAGS], [$BZIP2_LIBS]) PY_STDLIB_MOD([_lzma], [], [test "$have_liblzma" = yes], [$LIBLZMA_CFLAGS], [$LIBLZMA_LIBS]) +PY_STDLIB_MOD([_zstd], [], [test "$have_libzstd" = yes], + [$LIBZSTD_CFLAGS], [$LIBZSTD_LIBS]) dnl OpenSSL bindings PY_STDLIB_MOD([_ssl], [], [test "$ac_cv_working_openssl_ssl" = yes], diff --git a/pyconfig.h.in b/pyconfig.h.in index 6c17685e22a078..8c2c6ab5cea4cd 100644 --- a/pyconfig.h.in +++ b/pyconfig.h.in @@ -1630,12 +1630,18 @@ /* Define to 1 if you have the 'writev' function. */ #undef HAVE_WRITEV +/* Define to 1 if you have the header file. */ +#undef HAVE_ZDICT_H + /* Define if the zlib library has inflateCopy */ #undef HAVE_ZLIB_COPY /* Define to 1 if you have the header file. */ #undef HAVE_ZLIB_H +/* Define to 1 if you have the header file. */ +#undef HAVE_ZSTD_H + /* Define to 1 if you have the '_getpty' function. */ #undef HAVE__GETPTY