diff --git a/Cargo.lock b/Cargo.lock index 1fa60eac29..241d2a595d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2355,6 +2355,7 @@ dependencies = [ "rand 0.9.0", "rustpython-literal", "siphasher 0.3.11", + "unicode_names2", "volatile", "widestring", "windows-sys 0.59.0", diff --git a/Lib/_pycodecs.py b/Lib/_pycodecs.py index 0741504cc9..d0efa9ad6b 100644 --- a/Lib/_pycodecs.py +++ b/Lib/_pycodecs.py @@ -1086,11 +1086,13 @@ def charmapencode_output(c, mapping): rep = mapping[c] if isinstance(rep, int) or isinstance(rep, int): if rep < 256: - return rep + return [rep] else: raise TypeError("character mapping must be in range(256)") elif isinstance(rep, str): - return ord(rep) + return [ord(rep)] + elif isinstance(rep, bytes): + return rep elif rep == None: raise KeyError("character maps to ") else: @@ -1113,12 +1115,13 @@ def PyUnicode_EncodeCharmap(p, size, mapping='latin-1', errors='strict'): #/* try to encode it */ try: x = charmapencode_output(ord(p[inpos]), mapping) - res += [x] + res += x except KeyError: x = unicode_call_errorhandler(errors, "charmap", "character maps to ", p, inpos, inpos+1, False) try: - res += [charmapencode_output(ord(y), mapping) for y in x[0]] + for y in x[0]: + res += charmapencode_output(ord(y), mapping) except KeyError: raise UnicodeEncodeError("charmap", p, inpos, inpos+1, "character maps to ") diff --git a/Lib/test/test_charmapcodec.py b/Lib/test/test_charmapcodec.py index 8ea75d9129..0d4594d8c0 100644 --- a/Lib/test/test_charmapcodec.py +++ b/Lib/test/test_charmapcodec.py @@ -33,8 +33,6 @@ def test_constructorx(self): self.assertEqual(str(b'dxf', codecname), 'dabcf') self.assertEqual(str(b'dxfx', codecname), 'dabcfabc') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encodex(self): self.assertEqual('abc'.encode(codecname), b'abc') self.assertEqual('xdef'.encode(codecname), b'abcdef') diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index 293b75a866..09a6d883f8 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -203,8 +203,6 @@ def relaxedutf8(exc): self.assertRaises(UnicodeDecodeError, sin.decode, "utf-8", "test.relaxedutf8") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_charmapencode(self): # For charmap encodings the replacement string will be # mapped through the encoding again. 
This means, that @@ -329,8 +327,6 @@ def check_exceptionobjectargs(self, exctype, args, msg): exc = exctype(*args) self.assertEqual(str(exc), msg) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicodeencodeerror(self): self.check_exceptionobjectargs( UnicodeEncodeError, @@ -363,8 +359,6 @@ def test_unicodeencodeerror(self): "'ascii' codec can't encode character '\\U00010000' in position 0: ouch" ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicodedecodeerror(self): self.check_exceptionobjectargs( UnicodeDecodeError, @@ -377,8 +371,6 @@ def test_unicodedecodeerror(self): "'ascii' codec can't decode bytes in position 1-2: ouch" ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicodetranslateerror(self): self.check_exceptionobjectargs( UnicodeTranslateError, @@ -467,8 +459,6 @@ def test_badandgoodignoreexceptions(self): ("", 2) ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_badandgoodreplaceexceptions(self): # "replace" complains about a non-exception passed in self.assertRaises( @@ -509,8 +499,6 @@ def test_badandgoodreplaceexceptions(self): ("\ufffd", 2) ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_badandgoodxmlcharrefreplaceexceptions(self): # "xmlcharrefreplace" complains about a non-exception passed in self.assertRaises( @@ -1017,8 +1005,6 @@ def __getitem__(self, key): self.assertRaises(ValueError, codecs.charmap_decode, b"\xff", "strict", D()) self.assertRaises(TypeError, codecs.charmap_decode, b"\xff", "strict", {0xff: sys.maxunicode+1}) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encodehelper(self): # enhance coverage of: # Objects/unicodeobject.c::unicode_encode_call_errorhandler() diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index a570d65f6c..370685f1c6 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -5525,8 +5525,6 @@ def test_encoding_errors_default(self): self.assertEqual(data, r'\U0001f602: \u2603\ufe0f: The \xd8resund ' r'Bridge joins Copenhagen to Malm\xf6') - # TODO: RustPython - @unittest.expectedFailure def test_encoding_errors_none(self): # Specifying None should behave as 'strict' try: diff --git a/common/Cargo.toml b/common/Cargo.toml index e9aeba7459..3ebac23cbf 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -29,6 +29,7 @@ num-traits = { workspace = true } once_cell = { workspace = true } parking_lot = { workspace = true, optional = true } rand = { workspace = true } +unicode_names2 = { workspace = true } lock_api = "0.4" radium = "0.7" diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 7d99646c31..097dae17ba 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -1,13 +1,9 @@ -use std::ops::Range; +use std::ops::{self, Range}; use num_traits::ToPrimitive; use crate::str::StrKind; -use crate::wtf8::{Wtf8, Wtf8Buf}; - -pub type EncodeErrorResult = Result<(EncodeReplace, usize), E>; - -pub type DecodeErrorResult = Result<(S, Option, usize), E>; +use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; pub trait StrBuffer: AsRef { fn is_compatible_with(&self, kind: StrKind) -> bool { @@ -20,28 +16,116 @@ pub trait StrBuffer: AsRef { } } -pub trait ErrorHandler { +pub trait CodecContext: Sized { type Error; type StrBuf: StrBuffer; type BytesBuf: AsRef<[u8]>; + + fn string(&self, s: Wtf8Buf) -> Self::StrBuf; + fn bytes(&self, b: Vec) -> Self::BytesBuf; +} + +pub trait EncodeContext: CodecContext { + fn full_data(&self) -> &Wtf8; + fn data_len(&self) -> StrSize; + + fn remaining_data(&self) -> &Wtf8; + fn position(&self) 
-> StrSize; + + fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error>; + + fn error_encoding(&self, range: Range, reason: Option<&str>) -> Self::Error; + + fn handle_error( + &mut self, + errors: &E, + range: Range, + reason: Option<&str>, + ) -> Result, Self::Error> + where + E: EncodeErrorHandler, + { + let (replace, restart) = errors.handle_encode_error(self, range, reason)?; + self.restart_from(restart)?; + Ok(replace) + } +} + +pub trait DecodeContext: CodecContext { + fn full_data(&self) -> &[u8]; + + fn remaining_data(&self) -> &[u8]; + fn position(&self) -> usize; + + fn advance(&mut self, by: usize); + + fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error>; + + fn error_decoding(&self, byte_range: Range, reason: Option<&str>) -> Self::Error; + + fn handle_error( + &mut self, + errors: &E, + byte_range: Range, + reason: Option<&str>, + ) -> Result + where + E: DecodeErrorHandler, + { + let (replace, restart) = errors.handle_decode_error(self, byte_range, reason)?; + self.restart_from(restart)?; + Ok(replace) + } +} + +pub trait EncodeErrorHandler { fn handle_encode_error( &self, - data: &Wtf8, - char_range: Range, - reason: &str, - ) -> EncodeErrorResult; + ctx: &mut Ctx, + range: Range, + reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error>; +} +pub trait DecodeErrorHandler { fn handle_decode_error( &self, - data: &[u8], + ctx: &mut Ctx, byte_range: Range, - reason: &str, - ) -> DecodeErrorResult; - fn error_oob_restart(&self, i: usize) -> Self::Error; - fn error_encoding(&self, data: &Wtf8, char_range: Range, reason: &str) -> Self::Error; + reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error>; } -pub enum EncodeReplace { - Str(S), - Bytes(B), + +pub enum EncodeReplace { + Str(Ctx::StrBuf), + Bytes(Ctx::BytesBuf), +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct StrSize { + pub bytes: usize, + pub chars: usize, +} + +fn iter_code_points(w: &Wtf8) -> impl Iterator { + w.code_point_indices() + .enumerate() + .map(|(chars, (bytes, c))| (StrSize { bytes, chars }, c)) +} + +impl ops::Add for StrSize { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self { + bytes: self.bytes + rhs.bytes, + chars: self.chars + rhs.chars, + } + } +} +impl ops::AddAssign for StrSize { + fn add_assign(&mut self, rhs: Self) { + self.bytes += rhs.bytes; + self.chars += rhs.chars; + } } struct DecodeError<'a> { @@ -68,128 +152,341 @@ enum HandleResult<'a> { reason: &'a str, }, } -fn decode_utf8_compatible( - data: &[u8], +fn decode_utf8_compatible( + mut ctx: Ctx, errors: &E, decode: DecodeF, handle_error: ErrF, -) -> Result<(Wtf8Buf, usize), E::Error> +) -> Result<(Wtf8Buf, usize), Ctx::Error> where + Ctx: DecodeContext, + E: DecodeErrorHandler, DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>, - ErrF: Fn(&[u8], Option) -> HandleResult<'_>, + ErrF: Fn(&[u8], Option) -> HandleResult<'static>, { - if data.is_empty() { + if ctx.remaining_data().is_empty() { return Ok((Wtf8Buf::new(), 0)); } - // we need to coerce the lifetime to that of the function body rather than the - // anonymous input lifetime, so that we can assign it data borrowed from data_from_err - let mut data = data; - let mut data_from_err: E::BytesBuf; - let mut out = Wtf8Buf::with_capacity(data.len()); - let mut remaining_index = 0; - let mut remaining_data = data; + let mut out = Wtf8Buf::with_capacity(ctx.remaining_data().len()); loop { - match decode(remaining_data) { + match decode(ctx.remaining_data()) { Ok(decoded) => { out.push_str(decoded); - 
remaining_index += decoded.len(); + ctx.advance(decoded.len()); break; } Err(e) => { out.push_str(e.valid_prefix); match handle_error(e.rest, e.err_len) { HandleResult::Done => { - remaining_index += e.valid_prefix.len(); + ctx.advance(e.valid_prefix.len()); break; } HandleResult::Error { err_len, reason } => { - let err_idx = remaining_index + e.valid_prefix.len(); - let err_range = - err_idx..err_len.map_or_else(|| data.len(), |len| err_idx + len); - let (replace, new_data, restart) = - errors.handle_decode_error(data, err_range, reason)?; + let err_start = ctx.position() + e.valid_prefix.len(); + let err_end = match err_len { + Some(len) => err_start + len, + None => ctx.full_data().len(), + }; + let err_range = err_start..err_end; + let replace = ctx.handle_error(errors, err_range, Some(reason))?; out.push_wtf8(replace.as_ref()); - if let Some(new_data) = new_data { - data_from_err = new_data; - data = data_from_err.as_ref(); - } - remaining_data = data - .get(restart..) - .ok_or_else(|| errors.error_oob_restart(restart))?; - remaining_index = restart; continue; } } } } } - Ok((out, remaining_index)) + Ok((out, ctx.position())) } #[inline] -fn encode_utf8_compatible( - s: &Wtf8, +fn encode_utf8_compatible( + mut ctx: Ctx, errors: &E, err_reason: &str, target_kind: StrKind, -) -> Result, E::Error> { - let full_data = s; - let mut data = s; - let mut char_data_index = 0; - let mut out = Vec::::new(); - while let Some((char_i, (byte_i, _))) = data - .code_point_indices() - .enumerate() - .find(|(_, (_, c))| !target_kind.can_encode(*c)) - { - out.extend_from_slice(&data.as_bytes()[..byte_i]); - let char_start = char_data_index + char_i; +) -> Result, Ctx::Error> +where + Ctx: EncodeContext, + E: EncodeErrorHandler, +{ + // let mut data = s.as_ref(); + // let mut char_data_index = 0; + let mut out = Vec::::with_capacity(ctx.remaining_data().len()); + loop { + let data = ctx.remaining_data(); + let mut iter = iter_code_points(data); + let Some((i, _)) = iter.find(|(_, c)| !target_kind.can_encode(*c)) else { + break; + }; + + out.extend_from_slice(&ctx.remaining_data().as_bytes()[..i.bytes]); + let err_start = ctx.position() + i; // number of non-compatible chars between the first non-compatible char and the next compatible char - let non_compat_run_length = data[byte_i..] - .code_points() - .take_while(|c| !target_kind.can_encode(*c)) - .count(); - let char_range = char_start..char_start + non_compat_run_length; - let (replace, char_restart) = - errors.handle_encode_error(full_data, char_range.clone(), err_reason)?; + let err_end = match { iter }.find(|(_, c)| target_kind.can_encode(*c)) { + Some((i, _)) => ctx.position() + i, + None => ctx.data_len(), + }; + + let range = err_start..err_end; + let replace = ctx.handle_error(errors, range.clone(), Some(err_reason))?; match replace { EncodeReplace::Str(s) => { if s.is_compatible_with(target_kind) { out.extend_from_slice(s.as_ref().as_bytes()); } else { - return Err(errors.error_encoding(full_data, char_range, err_reason)); + return Err(ctx.error_encoding(range, Some(err_reason))); } } EncodeReplace::Bytes(b) => { out.extend_from_slice(b.as_ref()); } } - data = crate::str::try_get_codepoints(full_data, char_restart..) 
- .ok_or_else(|| errors.error_oob_restart(char_restart))?; - char_data_index = char_restart; } - out.extend_from_slice(data.as_bytes()); + out.extend_from_slice(ctx.remaining_data().as_bytes()); Ok(out) } +pub mod errors { + use crate::str::UnicodeEscapeCodepoint; + + use super::*; + use std::fmt::Write; + + pub struct Strict; + + impl EncodeErrorHandler for Strict { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + Err(ctx.error_encoding(range, reason)) + } + } + + impl DecodeErrorHandler for Strict { + fn handle_decode_error( + &self, + ctx: &mut Ctx, + byte_range: Range, + reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> { + Err(ctx.error_decoding(byte_range, reason)) + } + } + + pub struct Ignore; + + impl EncodeErrorHandler for Ignore { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + _reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + Ok((EncodeReplace::Bytes(ctx.bytes(b"".into())), range.end)) + } + } + + impl DecodeErrorHandler for Ignore { + fn handle_decode_error( + &self, + ctx: &mut Ctx, + byte_range: Range, + _reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> { + Ok((ctx.string("".into()), byte_range.end)) + } + } + + pub struct Replace; + + impl EncodeErrorHandler for Replace { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + _reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + let replace = "?".repeat(range.end.chars - range.start.chars); + Ok((EncodeReplace::Str(ctx.string(replace.into())), range.end)) + } + } + + impl DecodeErrorHandler for Replace { + fn handle_decode_error( + &self, + ctx: &mut Ctx, + byte_range: Range, + _reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> { + Ok(( + ctx.string(char::REPLACEMENT_CHARACTER.to_string().into()), + byte_range.end, + )) + } + } + + pub struct XmlCharRefReplace; + + impl EncodeErrorHandler for XmlCharRefReplace { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + _reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes]; + let num_chars = range.end.chars - range.start.chars; + // capacity rough guess; assuming that the codepoints are 3 digits in decimal + the &#; + let mut out = String::with_capacity(num_chars * 6); + for c in err_str.code_points() { + write!(out, "&#{};", c.to_u32()).unwrap() + } + Ok((EncodeReplace::Str(ctx.string(out.into())), range.end)) + } + } + + pub struct BackslashReplace; + + impl EncodeErrorHandler for BackslashReplace { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + _reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes]; + let num_chars = range.end.chars - range.start.chars; + // minimum 4 output bytes per char: \xNN + let mut out = String::with_capacity(num_chars * 4); + for c in err_str.code_points() { + write!(out, "{}", UnicodeEscapeCodepoint(c)).unwrap(); + } + Ok((EncodeReplace::Str(ctx.string(out.into())), range.end)) + } + } + + impl DecodeErrorHandler for BackslashReplace { + fn handle_decode_error( + &self, + ctx: &mut Ctx, + byte_range: Range, + _reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> { + let err_bytes = &ctx.full_data()[byte_range.clone()]; + let mut replace = String::with_capacity(4 * 
err_bytes.len()); + for &c in err_bytes { + write!(replace, "\\x{c:02x}").unwrap(); + } + Ok((ctx.string(replace.into()), byte_range.end)) + } + } + + pub struct NameReplace; + + impl EncodeErrorHandler for NameReplace { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + _reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes]; + let num_chars = range.end.chars - range.start.chars; + let mut out = String::with_capacity(num_chars * 4); + for c in err_str.code_points() { + let c_u32 = c.to_u32(); + if let Some(c_name) = unicode_names2::name(c.to_char_lossy()) { + write!(out, "\\N{{{c_name}}}").unwrap(); + } else if c_u32 >= 0x10000 { + write!(out, "\\U{c_u32:08x}").unwrap(); + } else if c_u32 >= 0x100 { + write!(out, "\\u{c_u32:04x}").unwrap(); + } else { + write!(out, "\\x{c_u32:02x}").unwrap(); + } + } + Ok((EncodeReplace::Str(ctx.string(out.into())), range.end)) + } + } + + pub struct SurrogateEscape; + + impl EncodeErrorHandler for SurrogateEscape { + fn handle_encode_error( + &self, + ctx: &mut Ctx, + range: Range, + reason: Option<&str>, + ) -> Result<(EncodeReplace, StrSize), Ctx::Error> { + let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes]; + let num_chars = range.end.chars - range.start.chars; + let mut out = Vec::with_capacity(num_chars); + for ch in err_str.code_points() { + let ch = ch.to_u32(); + if !(0xdc80..=0xdcff).contains(&ch) { + // Not a UTF-8b surrogate, fail with original exception + return Err(ctx.error_encoding(range, reason)); + } + out.push((ch - 0xdc00) as u8); + } + Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end)) + } + } + + impl DecodeErrorHandler for SurrogateEscape { + fn handle_decode_error( + &self, + ctx: &mut Ctx, + byte_range: Range, + reason: Option<&str>, + ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> { + let err_bytes = &ctx.full_data()[byte_range.clone()]; + let mut consumed = 0; + let mut replace = Wtf8Buf::with_capacity(4 * byte_range.len()); + while consumed < 4 && consumed < byte_range.len() { + let c = err_bytes[consumed] as u16; + // Refuse to escape ASCII bytes + if c < 128 { + break; + } + replace.push(CodePoint::from(0xdc00 + c)); + consumed += 1; + } + if consumed == 0 { + return Err(ctx.error_decoding(byte_range, reason)); + } + Ok((ctx.string(replace), byte_range.start + consumed)) + } + } +} + pub mod utf8 { use super::*; pub const ENCODING_NAME: &str = "utf-8"; #[inline] - pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { - encode_utf8_compatible(s, errors, "surrogates not allowed", StrKind::Utf8) + pub fn encode(ctx: Ctx, errors: &E) -> Result, Ctx::Error> + where + Ctx: EncodeContext, + E: EncodeErrorHandler, + { + encode_utf8_compatible(ctx, errors, "surrogates not allowed", StrKind::Utf8) } - pub fn decode( - data: &[u8], + pub fn decode>( + ctx: Ctx, errors: &E, final_decode: bool, - ) -> Result<(Wtf8Buf, usize), E::Error> { + ) -> Result<(Wtf8Buf, usize), Ctx::Error> { decode_utf8_compatible( - data, + ctx, errors, |v| { core::str::from_utf8(v).map_err(|e| { @@ -237,67 +534,55 @@ pub mod latin_1 { const ERR_REASON: &str = "ordinal not in range(256)"; #[inline] - pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { - let full_data = s; - let mut data = s; - let mut char_data_index = 0; + pub fn encode(mut ctx: Ctx, errors: &E) -> Result, Ctx::Error> + where + Ctx: EncodeContext, + E: EncodeErrorHandler, + { let mut out = Vec::::new(); loop { - match data - .code_point_indices() - 
.enumerate() - .find(|(_, (_, c))| !c.is_ascii()) - { - None => { - out.extend_from_slice(data.as_bytes()); - break; - } - Some((char_i, (byte_i, ch))) => { - out.extend_from_slice(&data.as_bytes()[..byte_i]); - let char_start = char_data_index + char_i; - if let Some(byte) = ch.to_u32().to_u8() { - out.push(byte); - // if the codepoint is between 128..=255, it's utf8-length is 2 - data = &data[byte_i + 2..]; - char_data_index = char_start + 1; - } else { - // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char - let non_latin_1_run_length = data[byte_i..] - .code_points() - .take_while(|c| c.to_u32() > 255) - .count(); - let char_range = char_start..char_start + non_latin_1_run_length; - let (replace, char_restart) = errors.handle_encode_error( - full_data, - char_range.clone(), - ERR_REASON, - )?; - match replace { - EncodeReplace::Str(s) => { - if s.as_ref().code_points().any(|c| c.to_u32() > 255) { - return Err( - errors.error_encoding(full_data, char_range, ERR_REASON) - ); - } - out.extend_from_slice(s.as_ref().as_bytes()); - } - EncodeReplace::Bytes(b) => { - out.extend_from_slice(b.as_ref()); - } + let data = ctx.remaining_data(); + let mut iter = iter_code_points(ctx.remaining_data()); + let Some((i, ch)) = iter.find(|(_, c)| !c.is_ascii()) else { + break; + }; + out.extend_from_slice(&data.as_bytes()[..i.bytes]); + let err_start = ctx.position() + i; + if let Some(byte) = ch.to_u32().to_u8() { + drop(iter); + out.push(byte); + // if the codepoint is between 128..=255, it's utf8-length is 2 + ctx.restart_from(err_start + StrSize { bytes: 2, chars: 1 })?; + } else { + // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char + let err_end = match { iter }.find(|(_, c)| c.to_u32() <= 255) { + Some((i, _)) => ctx.position() + i, + None => ctx.data_len(), + }; + let err_range = err_start..err_end; + let replace = ctx.handle_error(errors, err_range.clone(), Some(ERR_REASON))?; + match replace { + EncodeReplace::Str(s) => { + if s.as_ref().code_points().any(|c| c.to_u32() > 255) { + return Err(ctx.error_encoding(err_range, Some(ERR_REASON))); } - data = crate::str::try_get_codepoints(full_data, char_restart..) 
- .ok_or_else(|| errors.error_oob_restart(char_restart))?; - char_data_index = char_restart; + out.extend(s.as_ref().code_points().map(|c| c.to_u32() as u8)); + } + EncodeReplace::Bytes(b) => { + out.extend_from_slice(b.as_ref()); } - continue; } } } + out.extend_from_slice(ctx.remaining_data().as_bytes()); Ok(out) } - pub fn decode(data: &[u8], _errors: &E) -> Result<(Wtf8Buf, usize), E::Error> { - let out: String = data.iter().map(|c| *c as char).collect(); + pub fn decode>( + ctx: Ctx, + _errors: &E, + ) -> Result<(Wtf8Buf, usize), Ctx::Error> { + let out: String = ctx.remaining_data().iter().map(|c| *c as char).collect(); let out_len = out.len(); Ok((out.into(), out_len)) } @@ -312,13 +597,20 @@ pub mod ascii { const ERR_REASON: &str = "ordinal not in range(128)"; #[inline] - pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { - encode_utf8_compatible(s, errors, ERR_REASON, StrKind::Ascii) + pub fn encode(ctx: Ctx, errors: &E) -> Result, Ctx::Error> + where + Ctx: EncodeContext, + E: EncodeErrorHandler, + { + encode_utf8_compatible(ctx, errors, ERR_REASON, StrKind::Ascii) } - pub fn decode(data: &[u8], errors: &E) -> Result<(Wtf8Buf, usize), E::Error> { + pub fn decode>( + ctx: Ctx, + errors: &E, + ) -> Result<(Wtf8Buf, usize), Ctx::Error> { decode_utf8_compatible( - data, + ctx, errors, |v| { AsciiStr::from_ascii(v).map(|s| s.as_str()).map_err(|e| { diff --git a/common/src/str.rs b/common/src/str.rs index 176b5d0f87..8a00dcf1d8 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -446,6 +446,21 @@ pub fn to_ascii(value: &str) -> AsciiString { unsafe { AsciiString::from_ascii_unchecked(ascii) } } +pub struct UnicodeEscapeCodepoint(pub CodePoint); + +impl fmt::Display for UnicodeEscapeCodepoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let c = self.0.to_u32(); + if c >= 0x10000 { + write!(f, "\\U{c:08x}") + } else if c >= 0x100 { + write!(f, "\\u{c:04x}") + } else { + write!(f, "\\x{c:02x}") + } + } +} + pub mod levenshtein { use std::{cell::RefCell, thread_local}; diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs index f209b98a3e..f6ae628bad 100644 --- a/common/src/wtf8/mod.rs +++ b/common/src/wtf8/mod.rs @@ -618,6 +618,9 @@ impl ToOwned for Wtf8 { fn to_owned(&self) -> Self::Owned { self.to_wtf8_buf() } + fn clone_into(&self, buf: &mut Self::Owned) { + self.bytes.clone_into(&mut buf.bytes); + } } impl PartialEq for Wtf8 { @@ -872,8 +875,8 @@ impl Wtf8 { } } - pub fn clone_into(&self, buf: &mut Wtf8Buf) { - self.bytes.clone_into(&mut buf.bytes); + pub fn is_code_point_boundary(&self, index: usize) -> bool { + is_code_point_boundary(self, index) } /// Boxes this `Wtf8`. 
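As a quick reference for the escape widths produced by the new UnicodeEscapeCodepoint helper above (\xNN below 0x100, \uNNNN below 0x10000, \UNNNNNNNN otherwise), here is a minimal usage sketch. It is illustrative only, not part of the patch, and assumes the rustpython_common::{str, wtf8} paths that the vm imports use later in this diff:

// Not part of the patch: exercises the Display impl added in common/src/str.rs.
use rustpython_common::str::UnicodeEscapeCodepoint;
use rustpython_common::wtf8::CodePoint;

fn escape_examples() -> Vec<String> {
    [0xe9_u32, 0x0394, 0xdc80, 0x1f602]
        .into_iter()
        // surrogates are valid CodePoints; from_u32 only rejects values > 0x10FFFF
        .filter_map(CodePoint::from_u32)
        .map(|cp| UnicodeEscapeCodepoint(cp).to_string())
        .collect()
    // yields ["\xe9", "\u0394", "\udc80", "\U0001f602"]
}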
@@ -1101,6 +1104,7 @@ impl ops::Index> for Wtf8 { type Output = Wtf8; #[inline] + #[track_caller] fn index(&self, range: ops::Range) -> &Wtf8 { // is_code_point_boundary checks that the index is in [0, .len()] if range.start <= range.end @@ -1124,6 +1128,7 @@ impl ops::Index> for Wtf8 { type Output = Wtf8; #[inline] + #[track_caller] fn index(&self, range: ops::RangeFrom) -> &Wtf8 { // is_code_point_boundary checks that the index is in [0, .len()] if is_code_point_boundary(self, range.start) { @@ -1144,6 +1149,7 @@ impl ops::Index> for Wtf8 { type Output = Wtf8; #[inline] + #[track_caller] fn index(&self, range: ops::RangeTo) -> &Wtf8 { // is_code_point_boundary checks that the index is in [0, .len()] if is_code_point_boundary(self, range.end) { @@ -1171,7 +1177,7 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 { /// Copied from str::is_char_boundary #[inline] -pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { +fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { if index == 0 { return true; } @@ -1226,6 +1232,7 @@ pub unsafe fn slice_unchecked(s: &Wtf8, begin: usize, end: usize) -> &Wtf8 { /// Copied from core::str::raw::slice_error_fail #[inline(never)] +#[track_caller] pub fn slice_error_fail(s: &Wtf8, begin: usize, end: usize) -> ! { assert!(begin <= end); panic!("index {begin} and/or {end} in `{s:?}` do not lie on character boundary"); diff --git a/vm/src/codecs.rs b/vm/src/codecs.rs index bdb9b4b809..8d002916a6 100644 --- a/vm/src/codecs.rs +++ b/vm/src/codecs.rs @@ -1,13 +1,27 @@ -use rustpython_common::wtf8::{CodePoint, Wtf8Buf}; +use rustpython_common::{ + borrow::BorrowedValue, + encodings::{ + CodecContext, DecodeContext, DecodeErrorHandler, EncodeContext, EncodeErrorHandler, + EncodeReplace, StrBuffer, StrSize, errors, + }, + str::StrKind, + wtf8::{CodePoint, Wtf8, Wtf8Buf}, +}; use crate::{ - AsObject, Context, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, - builtins::{PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef}, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, + TryFromObject, VirtualMachine, + builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef}, common::{ascii, lock::PyRwLock}, convert::ToPyObject, - function::PyMethodDef, + function::{ArgBytesLike, PyMethodDef}, +}; +use once_cell::unsync::OnceCell; +use std::{ + borrow::Cow, + collections::HashMap, + ops::{self, Range}, }; -use std::{borrow::Cow, collections::HashMap, fmt::Write, ops::Range}; pub struct CodecsRegistry { inner: PyRwLock, @@ -360,6 +374,676 @@ fn normalize_encoding_name(encoding: &str) -> Cow<'_, str> { } } +#[derive(Eq, PartialEq)] +enum StandardEncoding { + Utf8, + Utf16Be, + Utf16Le, + Utf32Be, + Utf32Le, +} + +impl StandardEncoding { + #[cfg(target_endian = "little")] + const UTF_16_NE: Self = Self::Utf16Le; + #[cfg(target_endian = "big")] + const UTF_16_NE: Self = Self::Utf16Be; + + #[cfg(target_endian = "little")] + const UTF_32_NE: Self = Self::Utf32Le; + #[cfg(target_endian = "big")] + const UTF_32_NE: Self = Self::Utf32Be; + + fn parse(encoding: &str) -> Option { + if let Some(encoding) = encoding.to_lowercase().strip_prefix("utf") { + let encoding = encoding + .strip_prefix(|c| ['-', '_'].contains(&c)) + .unwrap_or(encoding); + if encoding == "8" { + Some(Self::Utf8) + } else if let Some(encoding) = encoding.strip_prefix("16") { + if encoding.is_empty() { + return Some(Self::UTF_16_NE); + } + let encoding = 
encoding.strip_prefix(['-', '_']).unwrap_or(encoding); + match encoding { + "be" => Some(Self::Utf16Be), + "le" => Some(Self::Utf16Le), + _ => None, + } + } else if let Some(encoding) = encoding.strip_prefix("32") { + if encoding.is_empty() { + return Some(Self::UTF_32_NE); + } + let encoding = encoding.strip_prefix(['-', '_']).unwrap_or(encoding); + match encoding { + "be" => Some(Self::Utf32Be), + "le" => Some(Self::Utf32Le), + _ => return None, + } + } else { + None + } + } else if encoding == "CP_UTF8" { + Some(Self::Utf8) + } else { + None + } + } +} + +struct SurrogatePass; + +impl<'a> EncodeErrorHandler> for SurrogatePass { + fn handle_encode_error( + &self, + ctx: &mut PyEncodeContext<'a>, + range: Range, + reason: Option<&str>, + ) -> PyResult<(EncodeReplace>, StrSize)> { + let standard_encoding = StandardEncoding::parse(ctx.encoding) + .ok_or_else(|| ctx.error_encoding(range.clone(), reason))?; + let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes]; + let num_chars = range.end.chars - range.start.chars; + let mut out: Vec = Vec::with_capacity(num_chars * 4); + for ch in err_str.code_points() { + let c = ch.to_u32(); + let 0xd800..=0xdfff = c else { + // Not a surrogate, fail with original exception + return Err(ctx.error_encoding(range, reason)); + }; + match standard_encoding { + StandardEncoding::Utf8 => out.extend(ch.encode_wtf8(&mut [0; 4]).as_bytes()), + StandardEncoding::Utf16Le => out.extend((c as u16).to_le_bytes()), + StandardEncoding::Utf16Be => out.extend((c as u16).to_be_bytes()), + StandardEncoding::Utf32Le => out.extend(c.to_le_bytes()), + StandardEncoding::Utf32Be => out.extend(c.to_be_bytes()), + } + } + Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end)) + } +} + +impl<'a> DecodeErrorHandler> for SurrogatePass { + fn handle_decode_error( + &self, + ctx: &mut PyDecodeContext<'a>, + byte_range: Range, + reason: Option<&str>, + ) -> PyResult<(PyStrRef, usize)> { + let standard_encoding = StandardEncoding::parse(ctx.encoding) + .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?; + + let s = ctx.full_data(); + debug_assert!(byte_range.start <= 0.max(s.len() - 1)); + debug_assert!(byte_range.end >= 1.min(s.len())); + debug_assert!(byte_range.end <= s.len()); + + // Try decoding a single surrogate character. If there are more, + // let the codec call us again. 
+ let p = &s[byte_range.start..]; + + fn slice(p: &[u8]) -> Option<[u8; N]> { + p.get(..N).map(|x| x.try_into().unwrap()) + } + + let c = match standard_encoding { + StandardEncoding::Utf8 => { + // it's a three-byte code + slice::<3>(p) + .filter(|&[a, b, c]| { + (u32::from(a) & 0xf0) == 0xe0 + && (u32::from(b) & 0xc0) == 0x80 + && (u32::from(c) & 0xc0) == 0x80 + }) + .map(|[a, b, c]| { + ((u32::from(a) & 0x0f) << 12) + + ((u32::from(b) & 0x3f) << 6) + + (u32::from(c) & 0x3f) + }) + } + StandardEncoding::Utf16Le => slice(p).map(u16::from_le_bytes).map(u32::from), + StandardEncoding::Utf16Be => slice(p).map(u16::from_be_bytes).map(u32::from), + StandardEncoding::Utf32Le => slice(p).map(u32::from_le_bytes), + StandardEncoding::Utf32Be => slice(p).map(u32::from_be_bytes), + }; + let byte_length = match standard_encoding { + StandardEncoding::Utf8 => 3, + StandardEncoding::Utf16Be | StandardEncoding::Utf16Le => 2, + StandardEncoding::Utf32Be | StandardEncoding::Utf32Le => 4, + }; + + // !Py_UNICODE_IS_SURROGATE + let c = c + .and_then(CodePoint::from_u32) + .filter(|c| matches!(c.to_u32(), 0xd800..=0xdfff)) + .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?; + + Ok((ctx.string(c.into()), byte_range.start + byte_length)) + } +} + +pub struct PyEncodeContext<'a> { + vm: &'a VirtualMachine, + encoding: &'a str, + data: &'a Py, + pos: StrSize, + exception: OnceCell, +} + +impl<'a> PyEncodeContext<'a> { + pub fn new(encoding: &'a str, data: &'a Py, vm: &'a VirtualMachine) -> Self { + Self { + vm, + encoding, + data, + pos: StrSize::default(), + exception: OnceCell::new(), + } + } +} + +impl CodecContext for PyEncodeContext<'_> { + type Error = PyBaseExceptionRef; + type StrBuf = PyStrRef; + type BytesBuf = PyBytesRef; + + fn string(&self, s: Wtf8Buf) -> Self::StrBuf { + self.vm.ctx.new_str(s) + } + + fn bytes(&self, b: Vec) -> Self::BytesBuf { + self.vm.ctx.new_bytes(b) + } +} +impl EncodeContext for PyEncodeContext<'_> { + fn full_data(&self) -> &Wtf8 { + self.data.as_wtf8() + } + + fn data_len(&self) -> StrSize { + StrSize { + bytes: self.data.byte_len(), + chars: self.data.char_len(), + } + } + + fn remaining_data(&self) -> &Wtf8 { + &self.full_data()[self.pos.bytes..] 
+ } + + fn position(&self) -> StrSize { + self.pos + } + + fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error> { + if pos.chars > self.data.char_len() { + return Err(self.vm.new_index_error(format!( + "position {} from error handler out of bounds", + pos.chars + ))); + } + assert!( + self.data.as_wtf8().is_code_point_boundary(pos.bytes), + "invalid pos {pos:?} for {:?}", + self.data.as_wtf8() + ); + self.pos = pos; + Ok(()) + } + + fn error_encoding(&self, range: Range, reason: Option<&str>) -> Self::Error { + let vm = self.vm; + match self.exception.get() { + Some(exc) => { + match update_unicode_error_attrs( + exc.as_object(), + range.start.chars, + range.end.chars, + reason, + vm, + ) { + Ok(()) => exc.clone(), + Err(e) => e, + } + } + None => self + .exception + .get_or_init(|| { + let reason = reason.expect( + "should only ever pass reason: None if an exception is already set", + ); + vm.new_unicode_encode_error_real( + vm.ctx.new_str(self.encoding), + self.data.to_owned(), + range.start.chars, + range.end.chars, + vm.ctx.new_str(reason), + ) + }) + .clone(), + } + } +} + +pub struct PyDecodeContext<'a> { + vm: &'a VirtualMachine, + encoding: &'a str, + data: PyDecodeData<'a>, + orig_bytes: Option<&'a Py>, + pos: usize, + exception: OnceCell, +} +enum PyDecodeData<'a> { + Original(BorrowedValue<'a, [u8]>), + Modified(PyBytesRef), +} +impl ops::Deref for PyDecodeData<'_> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + match self { + PyDecodeData::Original(data) => data, + PyDecodeData::Modified(data) => data, + } + } +} + +impl<'a> PyDecodeContext<'a> { + pub fn new(encoding: &'a str, data: &'a ArgBytesLike, vm: &'a VirtualMachine) -> Self { + Self { + vm, + encoding, + data: PyDecodeData::Original(data.borrow_buf()), + orig_bytes: data.as_object().downcast_ref(), + pos: 0, + exception: OnceCell::new(), + } + } +} + +impl CodecContext for PyDecodeContext<'_> { + type Error = PyBaseExceptionRef; + type StrBuf = PyStrRef; + type BytesBuf = PyBytesRef; + + fn string(&self, s: Wtf8Buf) -> Self::StrBuf { + self.vm.ctx.new_str(s) + } + + fn bytes(&self, b: Vec) -> Self::BytesBuf { + self.vm.ctx.new_bytes(b) + } +} +impl DecodeContext for PyDecodeContext<'_> { + fn full_data(&self) -> &[u8] { + &self.data + } + + fn remaining_data(&self) -> &[u8] { + &self.data[self.pos..] 
+ } + + fn position(&self) -> usize { + self.pos + } + + fn advance(&mut self, by: usize) { + self.pos += by; + } + + fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error> { + if pos > self.data.len() { + return Err(self + .vm + .new_index_error(format!("position {pos} from error handler out of bounds",))); + } + self.pos = pos; + Ok(()) + } + + fn error_decoding(&self, byte_range: Range, reason: Option<&str>) -> Self::Error { + let vm = self.vm; + + match self.exception.get() { + Some(exc) => { + match update_unicode_error_attrs( + exc.as_object(), + byte_range.start, + byte_range.end, + reason, + vm, + ) { + Ok(()) => exc.clone(), + Err(e) => e, + } + } + None => self + .exception + .get_or_init(|| { + let reason = reason.expect( + "should only ever pass reason: None if an exception is already set", + ); + let data = if let Some(bytes) = self.orig_bytes { + bytes.to_owned() + } else { + vm.ctx.new_bytes(self.data.to_vec()) + }; + vm.new_unicode_decode_error_real( + vm.ctx.new_str(self.encoding), + data, + byte_range.start, + byte_range.end, + vm.ctx.new_str(reason), + ) + }) + .clone(), + } + } +} + +#[derive(strum_macros::EnumString)] +#[strum(serialize_all = "lowercase")] +enum StandardError { + Strict, + Ignore, + Replace, + XmlCharRefReplace, + BackslashReplace, + SurrogatePass, + SurrogateEscape, +} + +impl<'a> EncodeErrorHandler> for StandardError { + fn handle_encode_error( + &self, + ctx: &mut PyEncodeContext<'a>, + range: Range, + reason: Option<&str>, + ) -> PyResult<(EncodeReplace>, StrSize)> { + use StandardError::*; + // use errors::*; + match self { + Strict => errors::Strict.handle_encode_error(ctx, range, reason), + Ignore => errors::Ignore.handle_encode_error(ctx, range, reason), + Replace => errors::Replace.handle_encode_error(ctx, range, reason), + XmlCharRefReplace => errors::XmlCharRefReplace.handle_encode_error(ctx, range, reason), + BackslashReplace => errors::BackslashReplace.handle_encode_error(ctx, range, reason), + SurrogatePass => SurrogatePass.handle_encode_error(ctx, range, reason), + SurrogateEscape => errors::SurrogateEscape.handle_encode_error(ctx, range, reason), + } + } +} + +impl<'a> DecodeErrorHandler> for StandardError { + fn handle_decode_error( + &self, + ctx: &mut PyDecodeContext<'a>, + byte_range: Range, + reason: Option<&str>, + ) -> PyResult<(PyStrRef, usize)> { + use StandardError::*; + match self { + Strict => errors::Strict.handle_decode_error(ctx, byte_range, reason), + Ignore => errors::Ignore.handle_decode_error(ctx, byte_range, reason), + Replace => errors::Replace.handle_decode_error(ctx, byte_range, reason), + XmlCharRefReplace => Err(ctx.vm.new_type_error( + "don't know how to handle UnicodeDecodeError in error callback".to_owned(), + )), + BackslashReplace => { + errors::BackslashReplace.handle_decode_error(ctx, byte_range, reason) + } + SurrogatePass => self::SurrogatePass.handle_decode_error(ctx, byte_range, reason), + SurrogateEscape => errors::SurrogateEscape.handle_decode_error(ctx, byte_range, reason), + } + } +} + +pub struct ErrorsHandler<'a> { + errors: &'a Py, + resolved: OnceCell, +} +enum ResolvedError { + Standard(StandardError), + Handler(PyObjectRef), +} + +impl<'a> ErrorsHandler<'a> { + #[inline] + pub fn new(errors: Option<&'a Py>, vm: &VirtualMachine) -> Self { + match errors { + Some(errors) => Self { + errors, + resolved: OnceCell::new(), + }, + None => Self { + errors: identifier!(vm, strict).as_ref(), + resolved: OnceCell::with_value(ResolvedError::Standard(StandardError::Strict)), + }, + } + } + 
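The handler/context split above keeps codec implementations generic over EncodeContext/DecodeContext, while ErrorsHandler lazily resolves an error name either to a built-in StandardError or to a callback registered in the codec registry. A rough, hypothetical sketch of how a caller inside this module might wire the pieces together (the real call sites live elsewhere in vm/src and are not part of this hunk):

// Hypothetical helper, for illustration only; assumes the imports already
// present in vm/src/codecs.rs plus rustpython_common::encodings::utf8.
fn encode_utf8_with_errors(
    vm: &VirtualMachine,
    s: &Py<PyStr>,
    errors: Option<&Py<PyStr>>,
) -> PyResult<Vec<u8>> {
    use rustpython_common::encodings::utf8;
    // PyEncodeContext tracks the current position as a StrSize (bytes + chars)
    // and lazily materializes a UnicodeEncodeError via error_encoding().
    let ctx = PyEncodeContext::new(utf8::ENCODING_NAME, s, vm);
    // Defaults to "strict"; otherwise resolved to a StandardError variant or a
    // handler looked up in vm.state.codec_registry.
    let handler = ErrorsHandler::new(errors, vm);
    utf8::encode(ctx, &handler)
}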
#[inline] + fn resolve(&self, vm: &VirtualMachine) -> PyResult<&ResolvedError> { + self.resolved.get_or_try_init(|| { + if let Ok(standard) = self.errors.as_str().parse() { + Ok(ResolvedError::Standard(standard)) + } else { + vm.state + .codec_registry + .lookup_error(self.errors.as_str(), vm) + .map(ResolvedError::Handler) + } + }) + } +} +impl StrBuffer for PyStrRef { + fn is_compatible_with(&self, kind: StrKind) -> bool { + self.kind() <= kind + } +} +impl<'a> EncodeErrorHandler> for ErrorsHandler<'_> { + fn handle_encode_error( + &self, + ctx: &mut PyEncodeContext<'a>, + range: Range, + reason: Option<&str>, + ) -> PyResult<(EncodeReplace>, StrSize)> { + let vm = ctx.vm; + let handler = match self.resolve(vm)? { + ResolvedError::Standard(standard) => { + return standard.handle_encode_error(ctx, range, reason); + } + ResolvedError::Handler(handler) => handler, + }; + let encode_exc = ctx.error_encoding(range.clone(), reason); + let res = handler.call((encode_exc.clone(),), vm)?; + let tuple_err = || { + vm.new_type_error( + "encoding error handler must return (str/bytes, int) tuple".to_owned(), + ) + }; + let (replace, restart) = match res.payload::().map(|tup| tup.as_slice()) { + Some([replace, restart]) => (replace.clone(), restart), + _ => return Err(tuple_err()), + }; + let replace = match_class!(match replace { + s @ PyStr => EncodeReplace::Str(s), + b @ PyBytes => EncodeReplace::Bytes(b), + _ => return Err(tuple_err()), + }); + let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?; + let restart = if restart < 0 { + // will still be out of bounds if it underflows ¯\_(ツ)_/¯ + ctx.data.char_len().wrapping_sub(restart.unsigned_abs()) + } else { + restart as usize + }; + let restart = if restart == range.end.chars { + range.end + } else { + StrSize { + chars: restart, + bytes: ctx + .data + .as_wtf8() + .code_point_indices() + .nth(restart) + .map_or(ctx.data.byte_len(), |(i, _)| i), + } + }; + Ok((replace, restart)) + } +} +impl<'a> DecodeErrorHandler> for ErrorsHandler<'_> { + fn handle_decode_error( + &self, + ctx: &mut PyDecodeContext<'a>, + byte_range: Range, + reason: Option<&str>, + ) -> PyResult<(PyStrRef, usize)> { + let vm = ctx.vm; + let handler = match self.resolve(vm)? { + ResolvedError::Standard(standard) => { + return standard.handle_decode_error(ctx, byte_range, reason); + } + ResolvedError::Handler(handler) => handler, + }; + let decode_exc = ctx.error_decoding(byte_range.clone(), reason); + let data_bytes: PyObjectRef = decode_exc.as_object().get_attr("object", vm)?; + let res = handler.call((decode_exc.clone(),), vm)?; + let new_data = decode_exc.as_object().get_attr("object", vm)?; + if !new_data.is(&data_bytes) { + let new_data: PyBytesRef = new_data + .downcast() + .map_err(|_| vm.new_type_error("object attribute must be bytes".to_owned()))?; + ctx.data = PyDecodeData::Modified(new_data); + } + let data = &*ctx.data; + let tuple_err = + || vm.new_type_error("decoding error handler must return (str, int) tuple".to_owned()); + match res.payload::().map(|tup| tup.as_slice()) { + Some([replace, restart]) => { + let replace = replace + .downcast_ref::() + .ok_or_else(tuple_err)? 
+ .to_owned(); + let restart = + isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?; + let restart = if restart < 0 { + // will still be out of bounds if it underflows ¯\_(ツ)_/¯ + data.len().wrapping_sub(restart.unsigned_abs()) + } else { + restart as usize + }; + Ok((replace, restart)) + } + _ => Err(tuple_err()), + } + } +} + +fn call_native_encode_error( + handler: E, + err: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<(PyObjectRef, usize)> +where + for<'a> E: EncodeErrorHandler>, +{ + // let err = err. + let range = extract_unicode_error_range(&err, vm)?; + let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; + let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?; + let mut ctx = PyEncodeContext { + vm, + encoding: s_encoding.as_str(), + data: &s, + pos: StrSize::default(), + exception: OnceCell::with_value(err.downcast().unwrap()), + }; + let mut iter = s.as_wtf8().code_point_indices(); + let start = StrSize { + chars: range.start, + bytes: iter.nth(range.start).unwrap().0, + }; + let end = StrSize { + chars: range.end, + bytes: if let Some(n) = range.len().checked_sub(1) { + iter.nth(n).map_or(s.byte_len(), |(i, _)| i) + } else { + start.bytes + }, + }; + let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?; + let replace = match replace { + EncodeReplace::Str(s) => s.into(), + EncodeReplace::Bytes(b) => b.into(), + }; + Ok((replace, restart.chars)) +} + +fn call_native_decode_error( + handler: E, + err: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<(PyObjectRef, usize)> +where + for<'a> E: DecodeErrorHandler>, +{ + let range = extract_unicode_error_range(&err, vm)?; + let s = ArgBytesLike::try_from_object(vm, err.get_attr("object", vm)?)?; + let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?; + let mut ctx = PyDecodeContext { + vm, + encoding: s_encoding.as_str(), + data: PyDecodeData::Original(s.borrow_buf()), + orig_bytes: s.as_object().downcast_ref(), + pos: 0, + exception: OnceCell::with_value(err.downcast().unwrap()), + }; + let (replace, restart) = handler.handle_decode_error(&mut ctx, range, None)?; + Ok((replace.into(), restart)) +} + +// this is a hack, for now +fn call_native_translate_error( + handler: E, + err: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<(PyObjectRef, usize)> +where + for<'a> E: EncodeErrorHandler>, +{ + // let err = err. 
+ let range = extract_unicode_error_range(&err, vm)?; + let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; + let mut ctx = PyEncodeContext { + vm, + encoding: "", + data: &s, + pos: StrSize::default(), + exception: OnceCell::with_value(err.downcast().unwrap()), + }; + let mut iter = s.as_wtf8().code_point_indices(); + let start = StrSize { + chars: range.start, + bytes: iter.nth(range.start).unwrap().0, + }; + let end = StrSize { + chars: range.end, + bytes: if let Some(n) = range.len().checked_sub(1) { + iter.nth(n).map_or(s.byte_len(), |(i, _)| i) + } else { + start.bytes + }, + }; + let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?; + let replace = match replace { + EncodeReplace::Str(s) => s.into(), + EncodeReplace::Bytes(b) => b.into(), + }; + Ok((replace, restart.chars)) +} + // TODO: exceptions with custom payloads fn extract_unicode_error_range(err: &PyObject, vm: &VirtualMachine) -> PyResult> { let start = err.get_attr("start", vm)?; @@ -369,14 +1053,32 @@ fn extract_unicode_error_range(err: &PyObject, vm: &VirtualMachine) -> PyResult< Ok(Range { start, end }) } +fn update_unicode_error_attrs( + err: &PyObject, + start: usize, + end: usize, + reason: Option<&str>, + vm: &VirtualMachine, +) -> PyResult<()> { + err.set_attr("start", start.to_pyobject(vm), vm)?; + err.set_attr("end", end.to_pyobject(vm), vm)?; + if let Some(reason) = reason { + err.set_attr("reason", reason.to_pyobject(vm), vm)?; + } + Ok(()) +} + +#[inline] +fn is_encode_err(err: &PyObject, vm: &VirtualMachine) -> bool { + err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) +} #[inline] fn is_decode_err(err: &PyObject, vm: &VirtualMachine) -> bool { err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error) } #[inline] -fn is_encode_ish_err(err: &PyObject, vm: &VirtualMachine) -> bool { - err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) - || err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error) +fn is_translate_err(err: &PyObject, vm: &VirtualMachine) -> bool { + err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error) } fn bad_err_type(err: PyObjectRef, vm: &VirtualMachine) -> PyBaseExceptionRef { @@ -394,7 +1096,7 @@ fn strict_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult { } fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> { - if is_encode_ish_err(&err, vm) || is_decode_err(&err, vm) { + if is_encode_err(&err, vm) || is_decode_err(&err, vm) || is_translate_err(&err, vm) { let range = extract_unicode_error_range(&err, vm)?; Ok((vm.ctx.new_str(ascii!("")).into(), range.end)) } else { @@ -402,325 +1104,71 @@ fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef } } -fn replace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> { - // char::REPLACEMENT_CHARACTER as a str - let replacement_char = "\u{FFFD}"; - let replace = if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) { - "?" 
- } else if err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error) { +fn replace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> { + if is_encode_err(&err, vm) { + call_native_encode_error(errors::Replace, err, vm) + } else if is_decode_err(&err, vm) { + call_native_decode_error(errors::Replace, err, vm) + } else if is_translate_err(&err, vm) { + // char::REPLACEMENT_CHARACTER as a str + let replacement_char = "\u{FFFD}"; let range = extract_unicode_error_range(&err, vm)?; - return Ok((replacement_char.to_owned(), range.end)); - } else if err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error) { - replacement_char + let replace = replacement_char.repeat(range.end - range.start); + Ok((replace.to_pyobject(vm), range.end)) } else { return Err(bad_err_type(err, vm)); - }; - let range = extract_unicode_error_range(&err, vm)?; - let replace = replace.repeat(range.end - range.start); - Ok((replace, range.end)) -} - -fn xmlcharrefreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> { - if !is_encode_ish_err(&err, vm) { - return Err(bad_err_type(err, vm)); - } - let range = extract_unicode_error_range(&err, vm)?; - let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = - crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); - let num_chars = range.len(); - // capacity rough guess; assuming that the codepoints are 3 digits in decimal + the &#; - let mut out = String::with_capacity(num_chars * 6); - for c in s_after_start.code_points().take(num_chars) { - write!(out, "&#{};", c.to_u32()).unwrap() } - Ok((out, range.end)) } -fn backslashreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> { - if is_decode_err(&err, vm) { - let range = extract_unicode_error_range(&err, vm)?; - let b = PyBytesRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let mut replace = String::with_capacity(4 * range.len()); - for &c in &b[range.clone()] { - write!(replace, "\\x{c:02x}").unwrap(); - } - return Ok((replace, range.end)); - } else if !is_encode_ish_err(&err, vm) { - return Err(bad_err_type(err, vm)); - } - let range = extract_unicode_error_range(&err, vm)?; - let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = - crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); - let num_chars = range.len(); - // minimum 4 output bytes per char: \xNN - let mut out = String::with_capacity(num_chars * 4); - for c in s_after_start.code_points().take(num_chars) { - let c = c.to_u32(); - if c >= 0x10000 { - write!(out, "\\U{c:08x}").unwrap(); - } else if c >= 0x100 { - write!(out, "\\u{c:04x}").unwrap(); - } else { - write!(out, "\\x{c:02x}").unwrap(); - } +fn xmlcharrefreplace_errors( + err: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<(PyObjectRef, usize)> { + if is_encode_err(&err, vm) { + call_native_encode_error(errors::XmlCharRefReplace, err, vm) + } else { + Err(bad_err_type(err, vm)) } - Ok((out, range.end)) } -fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> { - if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) { - let range = extract_unicode_error_range(&err, vm)?; - let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = - crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); - let num_chars = range.len(); - let mut out = 
String::with_capacity(num_chars * 4); - for c in s_after_start.code_points().take(num_chars) { - let c_u32 = c.to_u32(); - if let Some(c_name) = unicode_names2::name(c.to_char_lossy()) { - write!(out, "\\N{{{c_name}}}").unwrap(); - } else if c_u32 >= 0x10000 { - write!(out, "\\U{c_u32:08x}").unwrap(); - } else if c_u32 >= 0x100 { - write!(out, "\\u{c_u32:04x}").unwrap(); - } else { - write!(out, "\\x{c_u32:02x}").unwrap(); - } - } - Ok((out, range.end)) +fn backslashreplace_errors( + err: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<(PyObjectRef, usize)> { + if is_decode_err(&err, vm) { + call_native_decode_error(errors::BackslashReplace, err, vm) + } else if is_encode_err(&err, vm) { + call_native_encode_error(errors::BackslashReplace, err, vm) + } else if is_translate_err(&err, vm) { + call_native_translate_error(errors::BackslashReplace, err, vm) } else { Err(bad_err_type(err, vm)) } } -#[derive(Eq, PartialEq)] -enum StandardEncoding { - Utf8, - Utf16Be, - Utf16Le, - Utf32Be, - Utf32Le, - Unknown, -} - -fn get_standard_encoding(encoding: &str) -> (usize, StandardEncoding) { - if let Some(encoding) = encoding.to_lowercase().strip_prefix("utf") { - let mut byte_length: usize = 0; - let mut standard_encoding = StandardEncoding::Unknown; - let encoding = encoding - .strip_prefix(|c| ['-', '_'].contains(&c)) - .unwrap_or(encoding); - if encoding == "8" { - byte_length = 3; - standard_encoding = StandardEncoding::Utf8; - } else if let Some(encoding) = encoding.strip_prefix("16") { - byte_length = 2; - if encoding.is_empty() { - if cfg!(target_endian = "little") { - standard_encoding = StandardEncoding::Utf16Le; - } else if cfg!(target_endian = "big") { - standard_encoding = StandardEncoding::Utf16Be; - } - if standard_encoding != StandardEncoding::Unknown { - return (byte_length, standard_encoding); - } - } - let encoding = encoding - .strip_prefix(|c| ['-', '_'].contains(&c)) - .unwrap_or(encoding); - standard_encoding = match encoding { - "be" => StandardEncoding::Utf16Be, - "le" => StandardEncoding::Utf16Le, - _ => StandardEncoding::Unknown, - } - } else if let Some(encoding) = encoding.strip_prefix("32") { - byte_length = 4; - if encoding.is_empty() { - if cfg!(target_endian = "little") { - standard_encoding = StandardEncoding::Utf32Le; - } else if cfg!(target_endian = "big") { - standard_encoding = StandardEncoding::Utf32Be; - } - if standard_encoding != StandardEncoding::Unknown { - return (byte_length, standard_encoding); - } - } - let encoding = encoding - .strip_prefix(|c| ['-', '_'].contains(&c)) - .unwrap_or(encoding); - standard_encoding = match encoding { - "be" => StandardEncoding::Utf32Be, - "le" => StandardEncoding::Utf32Le, - _ => StandardEncoding::Unknown, - } - } - return (byte_length, standard_encoding); - } else if encoding == "CP_UTF8" { - return (3, StandardEncoding::Utf8); +fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> { + if is_encode_err(&err, vm) { + call_native_encode_error(errors::NameReplace, err, vm) + } else { + Err(bad_err_type(err, vm)) } - (0, StandardEncoding::Unknown) } fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> { - if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) { - let range = extract_unicode_error_range(&err, vm)?; - let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?; - let (_, standard_encoding) = 
get_standard_encoding(s_encoding.as_str()); - if let StandardEncoding::Unknown = standard_encoding { - // Not supported, fail with original exception - return Err(err.downcast().unwrap()); - } - let s_after_start = - crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); - let num_chars = range.len(); - let mut out: Vec = Vec::with_capacity(num_chars * 4); - for c in s_after_start.code_points().take(num_chars) { - let c = c.to_u32(); - if !(0xd800..=0xdfff).contains(&c) { - // Not a surrogate, fail with original exception - return Err(err.downcast().unwrap()); - } - match standard_encoding { - StandardEncoding::Utf8 => { - out.push((0xe0 | (c >> 12)) as u8); - out.push((0x80 | ((c >> 6) & 0x3f)) as u8); - out.push((0x80 | (c & 0x3f)) as u8); - } - StandardEncoding::Utf16Le => { - out.push(c as u8); - out.push((c >> 8) as u8); - } - StandardEncoding::Utf16Be => { - out.push((c >> 8) as u8); - out.push(c as u8); - } - StandardEncoding::Utf32Le => { - out.push(c as u8); - out.push((c >> 8) as u8); - out.push((c >> 16) as u8); - out.push((c >> 24) as u8); - } - StandardEncoding::Utf32Be => { - out.push((c >> 24) as u8); - out.push((c >> 16) as u8); - out.push((c >> 8) as u8); - out.push(c as u8); - } - StandardEncoding::Unknown => { - unreachable!("NOTE: RUSTPYTHON, should've bailed out earlier") - } - } - } - Ok((vm.ctx.new_bytes(out).into(), range.end)) + if is_encode_err(&err, vm) { + call_native_encode_error(SurrogatePass, err, vm) } else if is_decode_err(&err, vm) { - let range = extract_unicode_error_range(&err, vm)?; - let s = PyBytesRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?; - let (byte_length, standard_encoding) = get_standard_encoding(s_encoding.as_str()); - if let StandardEncoding::Unknown = standard_encoding { - // Not supported, fail with original exception - return Err(err.downcast().unwrap()); - } - - debug_assert!(range.start <= 0.max(s.len() - 1)); - debug_assert!(range.end >= 1.min(s.len())); - debug_assert!(range.end <= s.len()); - - let mut c: u32 = 0; - // Try decoding a single surrogate character. If there are more, - // let the codec call us again. 
-        let p = &s.as_bytes()[range.start..];
-        if p.len().overflowing_sub(range.start).0 >= byte_length {
-            match standard_encoding {
-                StandardEncoding::Utf8 => {
-                    if (p[0] as u32 & 0xf0) == 0xe0
-                        && (p[1] as u32 & 0xc0) == 0x80
-                        && (p[2] as u32 & 0xc0) == 0x80
-                    {
-                        // it's a three-byte code
-                        c = ((p[0] as u32 & 0x0f) << 12)
-                            + ((p[1] as u32 & 0x3f) << 6)
-                            + (p[2] as u32 & 0x3f);
-                    }
-                }
-                StandardEncoding::Utf16Le => {
-                    c = ((p[1] as u32) << 8) | p[0] as u32;
-                }
-                StandardEncoding::Utf16Be => {
-                    c = ((p[0] as u32) << 8) | p[1] as u32;
-                }
-                StandardEncoding::Utf32Le => {
-                    c = ((p[3] as u32) << 24)
-                        | ((p[2] as u32) << 16)
-                        | ((p[1] as u32) << 8)
-                        | p[0] as u32;
-                }
-                StandardEncoding::Utf32Be => {
-                    c = ((p[0] as u32) << 24)
-                        | ((p[1] as u32) << 16)
-                        | ((p[2] as u32) << 8)
-                        | p[3] as u32;
-                }
-                StandardEncoding::Unknown => {
-                    unreachable!("NOTE: RUSTPYTHON, should've bailed out earlier")
-                }
-            }
-        }
-        // !Py_UNICODE_IS_SURROGATE
-        if !(0xd800..=0xdfff).contains(&c) {
-            // Not a surrogate, fail with original exception
-            return Err(err.downcast().unwrap());
-        }
-
-        Ok((
-            vm.new_pyobj(CodePoint::from_u32(c).unwrap()),
-            range.start + byte_length,
-        ))
+        call_native_decode_error(SurrogatePass, err, vm)
     } else {
         Err(bad_err_type(err, vm))
     }
 }
 
 fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
-    if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) {
-        let range = extract_unicode_error_range(&err, vm)?;
-        let object = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
-        let s_after_start = crate::common::str::try_get_codepoints(object.as_wtf8(), range.start..)
-            .unwrap_or_default();
-        let mut out: Vec<u8> = Vec::with_capacity(range.len());
-        for ch in s_after_start.code_points().take(range.len()) {
-            let ch = ch.to_u32();
-            if !(0xdc80..=0xdcff).contains(&ch) {
-                // Not a UTF-8b surrogate, fail with original exception
-                return Err(err.downcast().unwrap());
-            }
-            out.push((ch - 0xdc00) as u8);
-        }
-        let out = vm.ctx.new_bytes(out);
-        Ok((out.into(), range.end))
+    if is_encode_err(&err, vm) {
+        call_native_encode_error(errors::SurrogateEscape, err, vm)
     } else if is_decode_err(&err, vm) {
-        let range = extract_unicode_error_range(&err, vm)?;
-        let object = err.get_attr("object", vm)?;
-        let object = PyBytesRef::try_from_object(vm, object)?;
-        let p = &object.as_bytes()[range.clone()];
-        let mut consumed = 0;
-        let mut replace = Wtf8Buf::with_capacity(4 * range.len());
-        while consumed < 4 && consumed < range.len() {
-            let c = p[consumed] as u16;
-            // Refuse to escape ASCII bytes
-            if c < 128 {
-                break;
-            }
-            replace.push(CodePoint::from(0xdc00 + c));
-            consumed += 1;
-        }
-        if consumed == 0 {
-            return Err(err.downcast().unwrap());
-        }
-        Ok((vm.new_pyobj(replace), range.start + consumed))
+        call_native_decode_error(errors::SurrogateEscape, err, vm)
     } else {
         Err(bad_err_type(err, vm))
     }
diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs
index 9bf7372794..58f2a51b68 100644
--- a/vm/src/exceptions.rs
+++ b/vm/src/exceptions.rs
@@ -971,22 +971,10 @@ impl ExceptionZoo {
         extend_exception!(PySystemError, ctx, excs.system_error);
         extend_exception!(PyTypeError, ctx, excs.type_error);
         extend_exception!(PyValueError, ctx, excs.value_error);
-        extend_exception!(PyUnicodeError, ctx, excs.unicode_error, {
-            "encoding" => ctx.new_readonly_getset("encoding", excs.unicode_error, make_arg_getter(0)),
-            "object" => ctx.new_readonly_getset("object", excs.unicode_error, make_arg_getter(1)),
-            "start" => ctx.new_readonly_getset("start", excs.unicode_error, make_arg_getter(2)),
-            "end" => ctx.new_readonly_getset("end", excs.unicode_error, make_arg_getter(3)),
-            "reason" => ctx.new_readonly_getset("reason", excs.unicode_error, make_arg_getter(4)),
-        });
+        extend_exception!(PyUnicodeError, ctx, excs.unicode_error);
         extend_exception!(PyUnicodeDecodeError, ctx, excs.unicode_decode_error);
         extend_exception!(PyUnicodeEncodeError, ctx, excs.unicode_encode_error);
-        extend_exception!(PyUnicodeTranslateError, ctx, excs.unicode_translate_error, {
-            "encoding" => ctx.new_readonly_getset("encoding", excs.unicode_translate_error, none_getter),
-            "object" => ctx.new_readonly_getset("object", excs.unicode_translate_error, make_arg_getter(0)),
-            "start" => ctx.new_readonly_getset("start", excs.unicode_translate_error, make_arg_getter(1)),
-            "end" => ctx.new_readonly_getset("end", excs.unicode_translate_error, make_arg_getter(2)),
-            "reason" => ctx.new_readonly_getset("reason", excs.unicode_translate_error, make_arg_getter(3)),
-        });
+        extend_exception!(PyUnicodeTranslateError, ctx, excs.unicode_translate_error);
 
         #[cfg(feature = "jit")]
         extend_exception!(PyJitError, ctx, excs.jit_error);
@@ -1010,10 +998,6 @@ impl ExceptionZoo {
     }
 }
 
-fn none_getter(_obj: PyObjectRef, vm: &VirtualMachine) -> PyRef<PyNone> {
-    vm.ctx.none.clone()
-}
-
 fn make_arg_getter(idx: usize) -> impl Fn(PyBaseExceptionRef) -> Option<PyObjectRef> {
     move |exc| exc.get_arg(idx)
 }
@@ -1182,11 +1166,12 @@ pub(super) mod types {
             PyInt, PyStrRef, PyTupleRef, PyTypeRef, traceback::PyTracebackRef, tuple::IntoPyTuple,
         },
         convert::ToPyResult,
-        function::FuncArgs,
+        function::{ArgBytesLike, FuncArgs},
         types::{Constructor, Initializer},
     };
     use crossbeam_utils::atomic::AtomicCell;
     use itertools::Itertools;
+    use rustpython_common::str::UnicodeEscapeCodepoint;
 
     // This module is designed to be used as `use builtins::*;`.
     // Do not add any pub symbols not included in builtins module.
@@ -1662,18 +1647,153 @@ pub(super) mod types {
     #[derive(Debug)]
     pub struct PyUnicodeError {}
 
-    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_decode_error", impl)]
+    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_decode_error")]
     #[derive(Debug)]
     pub struct PyUnicodeDecodeError {}
 
-    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_encode_error", impl)]
+    #[pyexception]
+    impl PyUnicodeDecodeError {
+        #[pyslot]
+        #[pymethod(name = "__init__")]
+        pub(crate) fn slot_init(
+            zelf: PyObjectRef,
+            args: FuncArgs,
+            vm: &VirtualMachine,
+        ) -> PyResult<()> {
+            type Args = (PyStrRef, ArgBytesLike, isize, isize, PyStrRef);
+            let (encoding, object, start, end, reason): Args = args.bind(vm)?;
+            zelf.set_attr("encoding", encoding, vm)?;
+            zelf.set_attr("object", object, vm)?;
+            zelf.set_attr("start", vm.ctx.new_int(start), vm)?;
+            zelf.set_attr("end", vm.ctx.new_int(end), vm)?;
+            zelf.set_attr("reason", reason, vm)?;
+            Ok(())
+        }
+
+        #[pymethod(magic)]
+        fn str(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult<String> {
+            let Ok(object) = exc.as_object().get_attr("object", vm) else {
+                return Ok("".to_owned());
+            };
+            let object: ArgBytesLike = object.try_into_value(vm)?;
+            let encoding: PyStrRef = exc
+                .as_object()
+                .get_attr("encoding", vm)?
+                .try_into_value(vm)?;
+            let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?;
+            let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?;
+            let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?;
+            if start < object.len() && end <= object.len() && end == start + 1 {
+                let b = object.borrow_buf()[start];
+                Ok(format!(
+                    "'{encoding}' codec can't decode byte {b:#02x} in position {start}: {reason}"
+                ))
+            } else {
+                Ok(format!(
+                    "'{encoding}' codec can't decode bytes in position {start}-{}: {reason}",
+                    end - 1,
+                ))
+            }
+        }
+    }
+
+    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_encode_error")]
     #[derive(Debug)]
     pub struct PyUnicodeEncodeError {}
 
-    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_translate_error", impl)]
+    #[pyexception]
+    impl PyUnicodeEncodeError {
+        #[pyslot]
+        #[pymethod(name = "__init__")]
+        pub(crate) fn slot_init(
+            zelf: PyObjectRef,
+            args: FuncArgs,
+            vm: &VirtualMachine,
+        ) -> PyResult<()> {
+            type Args = (PyStrRef, PyStrRef, isize, isize, PyStrRef);
+            let (encoding, object, start, end, reason): Args = args.bind(vm)?;
+            zelf.set_attr("encoding", encoding, vm)?;
+            zelf.set_attr("object", object, vm)?;
+            zelf.set_attr("start", vm.ctx.new_int(start), vm)?;
+            zelf.set_attr("end", vm.ctx.new_int(end), vm)?;
+            zelf.set_attr("reason", reason, vm)?;
+            Ok(())
+        }
+
+        #[pymethod(magic)]
+        fn str(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult<String> {
+            let Ok(object) = exc.as_object().get_attr("object", vm) else {
+                return Ok("".to_owned());
+            };
+            let object: PyStrRef = object.try_into_value(vm)?;
+            let encoding: PyStrRef = exc
+                .as_object()
+                .get_attr("encoding", vm)?
+                .try_into_value(vm)?;
+            let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?;
+            let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?;
+            let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?;
+            if start < object.char_len() && end <= object.char_len() && end == start + 1 {
+                let ch = object.as_wtf8().code_points().nth(start).unwrap();
+                Ok(format!(
+                    "'{encoding}' codec can't encode character '{}' in position {start}: {reason}",
+                    UnicodeEscapeCodepoint(ch)
+                ))
+            } else {
+                Ok(format!(
+                    "'{encoding}' codec can't encode characters in position {start}-{}: {reason}",
+                    end - 1,
+                ))
+            }
+        }
+    }
+
+    #[pyexception(name, base = "PyUnicodeError", ctx = "unicode_translate_error")]
     #[derive(Debug)]
     pub struct PyUnicodeTranslateError {}
 
+    #[pyexception]
+    impl PyUnicodeTranslateError {
+        #[pyslot]
+        #[pymethod(name = "__init__")]
+        pub(crate) fn slot_init(
+            zelf: PyObjectRef,
+            args: FuncArgs,
+            vm: &VirtualMachine,
+        ) -> PyResult<()> {
+            type Args = (PyStrRef, isize, isize, PyStrRef);
+            let (object, start, end, reason): Args = args.bind(vm)?;
+            zelf.set_attr("object", object, vm)?;
+            zelf.set_attr("start", vm.ctx.new_int(start), vm)?;
+            zelf.set_attr("end", vm.ctx.new_int(end), vm)?;
+            zelf.set_attr("reason", reason, vm)?;
+            Ok(())
+        }
+
+        #[pymethod(magic)]
+        fn str(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult<String> {
+            let Ok(object) = exc.as_object().get_attr("object", vm) else {
+                return Ok("".to_owned());
+            };
+            let object: PyStrRef = object.try_into_value(vm)?;
+            let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?;
+            let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?;
+            let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?;
+            if start < object.char_len() && end <= object.char_len() && end == start + 1 {
+                let ch = object.as_wtf8().code_points().nth(start).unwrap();
+                Ok(format!(
+                    "can't translate character '{}' in position {start}: {reason}",
+                    UnicodeEscapeCodepoint(ch)
+                ))
+            } else {
+                Ok(format!(
+                    "can't translate characters in position {start}-{}: {reason}",
+                    end - 1,
+                ))
+            }
+        }
+    }
+
     /// JIT error.
     #[cfg(feature = "jit")]
     #[pyexception(name, base = "PyException", ctx = "jit_error", impl)]
diff --git a/vm/src/function/buffer.rs b/vm/src/function/buffer.rs
index 80b36833e5..40a0e04d7e 100644
--- a/vm/src/function/buffer.rs
+++ b/vm/src/function/buffer.rs
@@ -70,6 +70,12 @@ impl From<ArgBytesLike> for PyBuffer {
     }
 }
 
+impl From<ArgBytesLike> for PyObjectRef {
+    fn from(buffer: ArgBytesLike) -> Self {
+        buffer.as_object().to_owned()
+    }
+}
+
 impl<'a> TryFromBorrowedObject<'a> for ArgBytesLike {
     fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult<Self> {
         let buffer = PyBuffer::try_from_borrowed_object(vm, obj)?;
diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs
index 664fe00616..6ad2a74f4b 100644
--- a/vm/src/stdlib/codecs.rs
+++ b/vm/src/stdlib/codecs.rs
@@ -2,16 +2,15 @@ pub(crate) use _codecs::make_module;
 
 #[pymodule]
 mod _codecs {
+    use crate::codecs::{ErrorsHandler, PyDecodeContext, PyEncodeContext};
     use crate::common::encodings;
-    use crate::common::str::StrKind;
-    use crate::common::wtf8::{Wtf8, Wtf8Buf};
+    use crate::common::wtf8::Wtf8Buf;
     use crate::{
-        AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine,
-        builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple},
+        AsObject, PyObjectRef, PyResult, VirtualMachine,
+        builtins::PyStrRef,
         codecs,
         function::{ArgBytesLike, FuncArgs},
     };
-    use std::ops::Range;
 
     #[pyfunction]
     fn register(search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
@@ -79,164 +78,6 @@ mod _codecs {
         vm.state.codec_registry.lookup_error(name.as_str(), vm)
     }
 
-    struct ErrorsHandler<'a> {
-        vm: &'a VirtualMachine,
-        encoding: &'a str,
-        errors: Option<PyStrRef>,
-        handler: once_cell::unsync::OnceCell<PyObjectRef>,
-    }
-    impl<'a> ErrorsHandler<'a> {
-        #[inline]
-        fn new(encoding: &'a str, errors: Option<PyStrRef>, vm: &'a VirtualMachine) -> Self {
-            ErrorsHandler {
-                vm,
-                encoding,
-                errors,
-                handler: Default::default(),
-            }
-        }
-        #[inline]
-        fn handler_func(&self) -> PyResult<&PyObject> {
-            let vm = self.vm;
-            Ok(self.handler.get_or_try_init(|| {
-                let errors = self.errors.as_ref().map_or("strict", |s| s.as_str());
-                vm.state.codec_registry.lookup_error(errors, vm)
-            })?)
-        }
-    }
-    impl encodings::StrBuffer for PyStrRef {
-        fn is_compatible_with(&self, kind: StrKind) -> bool {
-            self.kind() <= kind
-        }
-    }
-    impl encodings::ErrorHandler for ErrorsHandler<'_> {
-        type Error = PyBaseExceptionRef;
-        type StrBuf = PyStrRef;
-        type BytesBuf = PyBytesRef;
-
-        fn handle_encode_error(
-            &self,
-            data: &Wtf8,
-            char_range: Range<usize>,
-            reason: &str,
-        ) -> PyResult<(encodings::EncodeReplace<PyStrRef, PyBytesRef>, usize)> {
-            let vm = self.vm;
-            let data_str = vm.ctx.new_str(data).into();
-            let encode_exc = vm.new_exception(
-                vm.ctx.exceptions.unicode_encode_error.to_owned(),
-                vec![
-                    vm.ctx.new_str(self.encoding).into(),
-                    data_str,
-                    vm.ctx.new_int(char_range.start).into(),
-                    vm.ctx.new_int(char_range.end).into(),
-                    vm.ctx.new_str(reason).into(),
-                ],
-            );
-            let res = self.handler_func()?.call((encode_exc,), vm)?;
-            let tuple_err = || {
-                vm.new_type_error(
-                    "encoding error handler must return (str/bytes, int) tuple".to_owned(),
-                )
-            };
-            let (replace, restart) = match res.payload::<PyTuple>().map(|tup| tup.as_slice()) {
-                Some([replace, restart]) => (replace.clone(), restart),
-                _ => return Err(tuple_err()),
-            };
-            let replace = match_class!(match replace {
-                s @ PyStr => encodings::EncodeReplace::Str(s),
-                b @ PyBytes => encodings::EncodeReplace::Bytes(b),
-                _ => return Err(tuple_err()),
-            });
-            let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
-            let restart = if restart < 0 {
-                // will still be out of bounds if it underflows ¯\_(ツ)_/¯
-                data.len().wrapping_sub(restart.unsigned_abs())
-            } else {
-                restart as usize
-            };
-            Ok((replace, restart))
-        }
-
-        fn handle_decode_error(
-            &self,
-            data: &[u8],
-            byte_range: Range<usize>,
-            reason: &str,
-        ) -> PyResult<(PyStrRef, Option<PyBytesRef>, usize)> {
-            let vm = self.vm;
-            let data_bytes: PyObjectRef = vm.ctx.new_bytes(data.to_vec()).into();
-            let decode_exc = vm.new_exception(
-                vm.ctx.exceptions.unicode_decode_error.to_owned(),
-                vec![
-                    vm.ctx.new_str(self.encoding).into(),
-                    data_bytes.clone(),
-                    vm.ctx.new_int(byte_range.start).into(),
-                    vm.ctx.new_int(byte_range.end).into(),
-                    vm.ctx.new_str(reason).into(),
-                ],
-            );
-            let handler = self.handler_func()?;
-            let res = handler.call((decode_exc.clone(),), vm)?;
-            let new_data = decode_exc
-                .get_arg(1)
-                .ok_or_else(|| vm.new_type_error("object attribute not set".to_owned()))?;
-            let new_data = if new_data.is(&data_bytes) {
-                None
-            } else {
-                let new_data: PyBytesRef = new_data
-                    .downcast()
-                    .map_err(|_| vm.new_type_error("object attribute must be bytes".to_owned()))?;
-                Some(new_data)
-            };
-            let data = new_data.as_ref().map_or(data, |s| s.as_ref());
-            let tuple_err = || {
-                vm.new_type_error("decoding error handler must return (str, int) tuple".to_owned())
-            };
-            match res.payload::<PyTuple>().map(|tup| tup.as_slice()) {
-                Some([replace, restart]) => {
-                    let replace = replace
-                        .downcast_ref::<PyStr>()
-                        .ok_or_else(tuple_err)?
-                        .to_owned();
-                    let restart =
-                        isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
-                    let restart = if restart < 0 {
-                        // will still be out of bounds if it underflows ¯\_(ツ)_/¯
-                        data.len().wrapping_sub(restart.unsigned_abs())
-                    } else {
-                        restart as usize
-                    };
-                    Ok((replace, new_data, restart))
-                }
-                _ => Err(tuple_err()),
-            }
-        }
-
-        fn error_oob_restart(&self, i: usize) -> PyBaseExceptionRef {
-            self.vm
-                .new_index_error(format!("position {i} from error handler out of bounds"))
-        }
-
-        fn error_encoding(
-            &self,
-            data: &Wtf8,
-            char_range: Range<usize>,
-            reason: &str,
-        ) -> Self::Error {
-            let vm = self.vm;
-            vm.new_exception(
-                vm.ctx.exceptions.unicode_encode_error.to_owned(),
-                vec![
-                    vm.ctx.new_str(self.encoding).into(),
-                    vm.ctx.new_str(data).into(),
-                    vm.ctx.new_int(char_range.start).into(),
-                    vm.ctx.new_int(char_range.end).into(),
-                    vm.ctx.new_str(reason).into(),
-                ],
-            )
-        }
-    }
-
     type EncodeResult = PyResult<(Vec<u8>, usize)>;
 
     #[derive(FromArgs)]
@@ -249,12 +90,13 @@ mod _codecs {
 
     impl EncodeArgs {
         #[inline]
-        fn encode<'a, F>(self, name: &'a str, encode: F, vm: &'a VirtualMachine) -> EncodeResult
+        fn encode<'a, F>(&'a self, name: &'a str, encode: F, vm: &'a VirtualMachine) -> EncodeResult
         where
-            F: FnOnce(&Wtf8, &ErrorsHandler<'a>) -> PyResult<Vec<u8>>,
+            F: FnOnce(PyEncodeContext<'a>, &ErrorsHandler<'a>) -> PyResult<Vec<u8>>,
         {
-            let errors = ErrorsHandler::new(name, self.errors, vm);
-            let encoded = encode(self.s.as_wtf8(), &errors)?;
+            let ctx = PyEncodeContext::new(name, &self.s, vm);
+            let errors = ErrorsHandler::new(self.errors.as_deref(), vm);
+            let encoded = encode(ctx, &errors)?;
             Ok((encoded, self.s.char_len()))
         }
     }
@@ -273,13 +115,13 @@ mod _codecs {
 
     impl DecodeArgs {
         #[inline]
-        fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
+        fn decode<'a, F>(&'a self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
         where
-            F: FnOnce(&[u8], &ErrorsHandler<'a>, bool) -> DecodeResult,
+            F: FnOnce(PyDecodeContext<'a>, &ErrorsHandler<'a>, bool) -> DecodeResult,
         {
-            let data = self.data.borrow_buf();
-            let errors = ErrorsHandler::new(name, self.errors, vm);
-            decode(&data, &errors, self.final_decode)
+            let ctx = PyDecodeContext::new(name, &self.data, vm);
+            let errors = ErrorsHandler::new(self.errors.as_deref(), vm);
+            decode(ctx, &errors, self.final_decode)
         }
     }
 
@@ -293,13 +135,13 @@ mod _codecs {
 
     impl DecodeArgsNoFinal {
         #[inline]
-        fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
+        fn decode<'a, F>(&'a self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
         where
-            F: FnOnce(&[u8], &ErrorsHandler<'a>) -> DecodeResult,
+            F: FnOnce(PyDecodeContext<'a>, &ErrorsHandler<'a>) -> DecodeResult,
         {
-            let data = self.data.borrow_buf();
-            let errors = ErrorsHandler::new(name, self.errors, vm);
-            decode(&data, &errors)
+            let ctx = PyDecodeContext::new(name, &self.data, vm);
+            let errors = ErrorsHandler::new(self.errors.as_deref(), vm);
+            decode(ctx, &errors)
         }
     }
 
diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs
index 0b680251fa..77d9231724 100644
--- a/vm/src/stdlib/io.rs
+++ b/vm/src/stdlib/io.rs
@@ -2536,10 +2536,9 @@ mod _io {
             *snapshot = Some((cookie.dec_flags, input_chunk.clone()));
             let decoded = vm.call_method(decoder, "decode", (input_chunk, cookie.need_eof))?;
             let decoded = check_decoded(decoded, vm)?;
-            let pos_is_valid = crate::common::wtf8::is_code_point_boundary(
-                decoded.as_wtf8(),
-                cookie.bytes_to_skip as usize,
-            );
+            let pos_is_valid = decoded
+                .as_wtf8()
+                .is_code_point_boundary(cookie.bytes_to_skip as usize);
             textio.set_decoded_chars(Some(decoded));
             if !pos_is_valid {
                 return Err(vm.new_os_error("can't restore logical file position".to_owned()));
diff --git a/vm/src/vm/vm_new.rs b/vm/src/vm/vm_new.rs
index 12241414a7..3ceb783a48 100644
--- a/vm/src/vm/vm_new.rs
+++ b/vm/src/vm/vm_new.rs
@@ -1,7 +1,8 @@
 use crate::{
     AsObject, Py, PyObject, PyObjectRef, PyRef,
     builtins::{
-        PyBaseException, PyBaseExceptionRef, PyDictRef, PyModule, PyStrRef, PyType, PyTypeRef,
+        PyBaseException, PyBaseExceptionRef, PyBytesRef, PyDictRef, PyModule, PyStrRef, PyType,
+        PyTypeRef,
         builtin_func::PyNativeFunction,
         descriptor::PyMethodDescriptor,
         tuple::{IntoPyTuple, PyTupleRef},
@@ -203,16 +204,78 @@ impl VirtualMachine {
         self.new_exception_msg(sys_error, msg)
     }
 
+    // TODO: remove & replace with new_unicode_decode_error_real
     pub fn new_unicode_decode_error(&self, msg: String) -> PyBaseExceptionRef {
         let unicode_decode_error = self.ctx.exceptions.unicode_decode_error.to_owned();
         self.new_exception_msg(unicode_decode_error, msg)
     }
 
+    pub fn new_unicode_decode_error_real(
+        &self,
+        encoding: PyStrRef,
+        object: PyBytesRef,
+        start: usize,
+        end: usize,
+        reason: PyStrRef,
+    ) -> PyBaseExceptionRef {
+        let start = self.ctx.new_int(start);
+        let end = self.ctx.new_int(end);
+        let exc = self.new_exception(
+            self.ctx.exceptions.unicode_decode_error.to_owned(),
+            vec![
+                encoding.clone().into(),
+                object.clone().into(),
+                start.clone().into(),
+                end.clone().into(),
+                reason.clone().into(),
+            ],
+        );
+        exc.as_object()
+            .set_attr("encoding", encoding, self)
+            .unwrap();
+        exc.as_object().set_attr("object", object, self).unwrap();
+        exc.as_object().set_attr("start", start, self).unwrap();
+        exc.as_object().set_attr("end", end, self).unwrap();
+        exc.as_object().set_attr("reason", reason, self).unwrap();
+        exc
+    }
+
+    // TODO: remove & replace with new_unicode_encode_error_real
     pub fn new_unicode_encode_error(&self, msg: String) -> PyBaseExceptionRef {
         let unicode_encode_error = self.ctx.exceptions.unicode_encode_error.to_owned();
        self.new_exception_msg(unicode_encode_error, msg)
     }
 
+    pub fn new_unicode_encode_error_real(
+        &self,
+        encoding: PyStrRef,
+        object: PyStrRef,
+        start: usize,
+        end: usize,
+        reason: PyStrRef,
+    ) -> PyBaseExceptionRef {
+        let start = self.ctx.new_int(start);
+        let end = self.ctx.new_int(end);
+        let exc = self.new_exception(
+            self.ctx.exceptions.unicode_encode_error.to_owned(),
+            vec![
+                encoding.clone().into(),
+                object.clone().into(),
+                start.clone().into(),
+                end.clone().into(),
+                reason.clone().into(),
+            ],
+        );
+        exc.as_object()
+            .set_attr("encoding", encoding, self)
+            .unwrap();
+        exc.as_object().set_attr("object", object, self).unwrap();
+        exc.as_object().set_attr("start", start, self).unwrap();
+        exc.as_object().set_attr("end", end, self).unwrap();
+        exc.as_object().set_attr("reason", reason, self).unwrap();
+        exc
+    }
+
     /// Create a new python ValueError object. Useful for raising errors from
     /// python functions implemented in rust.
     pub fn new_value_error(&self, msg: String) -> PyBaseExceptionRef {