Split out common compression routines into separate file (#5728) · RustPython/RustPython@9c88475 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9c88475

Browse files
authored
Split out common compression routines into separate file (#5728)
1 parent 6aa80aa commit 9c88475

File tree

5 files changed

+286
-285
lines changed

5 files changed

+286
-285
lines changed

stdlib/src/bz2.rs renamed to stdlib/src/compression/bz2.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ pub(crate) use _bz2::make_module;
44

55
#[pymodule]
66
mod _bz2 {
7+
use super::super::{
8+
DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
9+
};
710
use crate::common::lock::PyMutex;
811
use crate::vm::{
912
VirtualMachine,
@@ -12,9 +15,6 @@ mod _bz2 {
1215
object::{PyPayload, PyResult},
1316
types::Constructor,
1417
};
15-
use crate::zlib::{
16-
DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
17-
};
1818
use bzip2::{Decompress, Status, write::BzEncoder};
1919
use rustpython_vm::convert::ToPyException;
2020
use std::{fmt, io::Write};
@@ -74,7 +74,7 @@ mod _bz2 {
7474
impl BZ2Decompressor {
7575
#[pymethod]
7676
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
77-
let max_length = args.max_length();
77+
let max_length = args.max_length_negative_is_none();
7878
let data = &*args.data();
7979

8080
let mut state = self.state.lock();

stdlib/src/compression/generic.rs

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
// cspell:ignore chunker
2+
use crate::vm::{
3+
VirtualMachine,
4+
builtins::{PyBaseExceptionRef, PyBytesRef},
5+
convert::ToPyException,
6+
function::{ArgBytesLike, ArgSize, OptionalArg},
7+
};
8+
9+
#[derive(FromArgs)]
10+
pub(super) struct DecompressArgs {
11+
#[pyarg(positional)]
12+
data: ArgBytesLike,
13+
#[pyarg(any, optional)]
14+
pub max_length: OptionalArg<ArgSize>,
15+
}
16+
17+
impl DecompressArgs {
18+
pub fn data(&self) -> crate::common::borrow::BorrowedValue<'_, [u8]> {
19+
self.data.borrow_buf()
20+
}
21+
pub fn max_length_negative_is_none(&self) -> Option<usize> {
22+
self.max_length
23+
.into_option()
24+
.and_then(|ArgSize { value }| usize::try_from(value).ok())
25+
}
26+
}
27+
28+
pub(super) trait Decompressor {
29+
type Flush: FlushKind;
30+
type Status: DecompressStatus;
31+
type Error;
32+
33+
fn total_in(&self) -> u64;
34+
fn decompress_vec(
35+
&mut self,
36+
input: &[u8],
37+
output: &mut Vec<u8>,
38+
flush: Self::Flush,
39+
) -> Result<Self::Status, Self::Error>;
40+
fn maybe_set_dict(&mut self, err: Self::Error) -> Result<(), Self::Error> {
41+
Err(err)
42+
}
43+
}
44+
45+
pub(super) trait DecompressStatus {
46+
fn is_stream_end(&self) -> bool;
47+
}
48+
49+
pub(super) trait FlushKind: Copy {
50+
const SYNC: Self;
51+
}
52+
53+
impl FlushKind for () {
54+
const SYNC: Self = ();
55+
}
56+
57+
pub(super) fn flush_sync<T: FlushKind>(_final_chunk: bool) -> T {
58+
T::SYNC
59+
}
60+
61+
pub(super) const CHUNKSIZE: usize = u32::MAX as usize;
62+
63+
#[derive(Clone)]
64+
pub(super) struct Chunker<'a> {
65+
data1: &'a [u8],
66+
data2: &'a [u8],
67+
}
68+
impl<'a> Chunker<'a> {
69+
pub fn new(data: &'a [u8]) -> Self {
70+
Self {
71+
data1: data,
72+
data2: &[],
73+
}
74+
}
75+
pub fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self {
76+
if data1.is_empty() {
77+
Self {
78+
data1: data2,
79+
data2: &[],
80+
}
81+
} else {
82+
Self { data1, data2 }
83+
}
84+
}
85+
pub fn len(&self) -> usize {
86+
self.data1.len() + self.data2.len()
87+
}
88+
pub fn is_empty(&self) -> bool {
89+
self.data1.is_empty()
90+
}
91+
pub fn to_vec(&self) -> Vec<u8> {
92+
[self.data1, self.data2].concat()
93+
}
94+
pub fn chunk(&self) -> &'a [u8] {
95+
self.data1.get(..CHUNKSIZE).unwrap_or(self.data1)
96+
}
97+
pub fn advance(&mut self, consumed: usize) {
98+
self.data1 = &self.data1[consumed..];
99+
if self.data1.is_empty() {
100+
self.data1 = std::mem::take(&mut self.data2);
101+
}
102+
}
103+
}
104+
105+
pub(super) fn _decompress<D: Decompressor>(
106+
data: &[u8],
107+
d: &mut D,
108+
bufsize: usize,
109+
max_length: Option<usize>,
110+
calc_flush: impl Fn(bool) -> D::Flush,
111+
) -> Result<(Vec<u8>, bool), D::Error> {
112+
let mut data = Chunker::new(data);
113+
_decompress_chunks(&mut data, d, bufsize, max_length, calc_flush)
114+
}
115+
116+
/// Core decompression driver: feed `data` to backend `d` chunk by chunk.
///
/// * `bufsize` – granularity by which the output buffer is grown.
/// * `max_length` – optional hard cap on output size; when hit, we stop
///   early and report `stream_end = false`.
/// * `calc_flush` – maps "is this the final chunk?" to the backend's
///   flush mode.
///
/// Returns `(output, stream_end)` where `stream_end` indicates whether
/// the end of the compressed stream was reached.
pub(super) fn _decompress_chunks<D: Decompressor>(
    data: &mut Chunker<'_>,
    d: &mut D,
    bufsize: usize,
    max_length: Option<usize>,
    calc_flush: impl Fn(bool) -> D::Flush,
) -> Result<(Vec<u8>, bool), D::Error> {
    if data.is_empty() {
        return Ok((Vec::new(), true));
    }
    let max_length = max_length.unwrap_or(usize::MAX);
    let mut buf = Vec::new();

    'outer: loop {
        let chunk = data.chunk();
        let flush = calc_flush(chunk.len() == data.len());
        loop {
            let additional = std::cmp::min(bufsize, max_length - buf.capacity());
            if additional == 0 {
                // output limit reached without hitting stream end
                return Ok((buf, false));
            }
            buf.reserve_exact(additional);

            // measure how much input the backend consumed via total_in
            let prev_in = d.total_in();
            let res = d.decompress_vec(chunk, &mut buf, flush);
            let consumed = d.total_in() - prev_in;

            data.advance(consumed as usize);

            match res {
                Ok(status) => {
                    let stream_end = status.is_stream_end();
                    if stream_end || data.is_empty() {
                        // we've reached the end of the stream, we're done
                        buf.shrink_to_fit();
                        return Ok((buf, stream_end));
                    } else if !chunk.is_empty() && consumed == 0 {
                        // we're gonna need a bigger buffer
                        continue;
                    } else {
                        // next chunk
                        continue 'outer;
                    }
                }
                Err(e) => {
                    // give the backend a chance to recover (e.g. zlib
                    // dictionary); otherwise the error propagates here
                    d.maybe_set_dict(e)?;
                    // now try the next chunk
                    continue 'outer;
                }
            };
        }
    }
}
169+
170+
#[derive(Debug)]
171+
pub(super) struct DecompressState<D> {
172+
decompress: D,
173+
unused_data: PyBytesRef,
174+
input_buffer: Vec<u8>,
175+
eof: bool,
176+
needs_input: bool,
177+
}
178+
179+
impl<D: Decompressor> DecompressState<D> {
180+
pub fn new(decompress: D, vm: &VirtualMachine) -> Self {
181+
Self {
182+
decompress,
183+
unused_data: vm.ctx.empty_bytes.clone(),
184+
input_buffer: Vec::new(),
185+
eof: false,
186+
needs_input: true,
187+
}
188+
}
189+
190+
pub fn eof(&self) -> bool {
191+
self.eof
192+
}
193+
194+
pub fn unused_data(&self) -> PyBytesRef {
195+
self.unused_data.clone()
196+
}
197+
198+
pub fn needs_input(&self) -> bool {
199+
self.needs_input
200+
}
201+
202+
pub fn decompress(
203+
&mut self,
204+
data: &[u8],
205+
max_length: Option<usize>,
206+
bufsize: usize,
207+
vm: &VirtualMachine,
208+
) -> Result<Vec<u8>, DecompressError<D::Error>> {
209+
if self.eof {
210+
return Err(DecompressError::Eof(EofError));
211+
}
212+
213+
let input_buffer = &mut self.input_buffer;
214+
let d = &mut self.decompress;
215+
216+
let mut chunks = Chunker::chain(input_buffer, data);
217+
218+
let prev_len = chunks.len();
219+
let (ret, stream_end) =
220+
match _decompress_chunks(&mut chunks, d, bufsize, max_length, flush_sync) {
221+
Ok((buf, stream_end)) => (Ok(buf), stream_end),
222+
Err(err) => (Err(err), false),
223+
};
224+
let consumed = prev_len - chunks.len();
225+
226+
self.eof |= stream_end;
227+
228+
if self.eof {
229+
self.needs_input = false;
230+
if !chunks.is_empty() {
231+
self.unused_data = vm.ctx.new_bytes(chunks.to_vec());
232+
}
233+
} else if chunks.is_empty() {
234+
input_buffer.clear();
235+
self.needs_input = true;
236+
} else {
237+
self.needs_input = false;
238+
if let Some(n_consumed_from_data) = consumed.checked_sub(input_buffer.len()) {
239+
input_buffer.clear();
240+
input_buffer.extend_from_slice(&data[n_consumed_from_data..]);
241+
} else {
242+
input_buffer.drain(..consumed);
243+
input_buffer.extend_from_slice(data);
244+
}
245+
}
246+
247+
ret.map_err(DecompressError::Decompress)
248+
}
249+
}
250+
251+
pub(super) enum DecompressError<E> {
252+
Decompress(E),
253+
Eof(EofError),
254+
}
255+
256+
impl<E> From<E> for DecompressError<E> {
257+
fn from(err: E) -> Self {
258+
Self::Decompress(err)
259+
}
260+
}
261+
262+
pub(super) struct EofError;
263+
264+
impl ToPyException for EofError {
265+
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
266+
vm.new_eof_error("End of stream already reached".to_owned())
267+
}
268+
}

stdlib/src/compression/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod generic;
2+
use generic::*;
3+
4+
pub mod bz2;
5+
pub mod zlib;

0 commit comments

Comments (0)