From c6cab4c43af41368e03153afbd509bbd307de337 Mon Sep 17 00:00:00 2001 From: Noa Date: Wed, 26 Mar 2025 20:35:59 -0500 Subject: [PATCH 1/3] Parse surrogates in string literals properly --- Cargo.lock | 3 + Lib/test/test_codeccallbacks.py | 6 - common/src/encodings.rs | 2 +- common/src/wtf8/mod.rs | 53 +++++ compiler/codegen/Cargo.toml | 2 + compiler/codegen/src/compile.rs | 138 ++++++++----- compiler/codegen/src/lib.rs | 1 + compiler/codegen/src/string_parser.rs | 287 ++++++++++++++++++++++++++ compiler/core/Cargo.toml | 1 + compiler/core/src/bytecode.rs | 5 +- compiler/core/src/marshal.rs | 12 +- jit/tests/common.rs | 4 +- vm/src/builtins/code.rs | 2 +- vm/src/builtins/str.rs | 12 ++ vm/src/intern.rs | 53 +++-- vm/src/stdlib/marshal.rs | 5 +- 16 files changed, 506 insertions(+), 80 deletions(-) create mode 100644 compiler/codegen/src/string_parser.rs diff --git a/Cargo.lock b/Cargo.lock index 241d2a595d..70d58bc448 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2319,6 +2319,7 @@ dependencies = [ "itertools 0.14.0", "log", "malachite-bigint", + "memchr", "num-complex", "num-traits", "ruff_python_ast", @@ -2330,6 +2331,7 @@ dependencies = [ "rustpython-compiler-core", "rustpython-compiler-source", "thiserror 2.0.11", + "unicode_names2", ] [[package]] @@ -2387,6 +2389,7 @@ dependencies = [ "ruff_python_ast", "ruff_python_parser", "ruff_source_file", + "rustpython-common", "serde", ] diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index 09a6d883f8..bd1dbcd626 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -536,8 +536,6 @@ def test_badandgoodxmlcharrefreplaceexceptions(self): ("".join("&#%d;" % c for c in cs), 1 + len(s)) ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_badandgoodbackslashreplaceexceptions(self): # "backslashreplace" complains about a non-exception passed in self.assertRaises( @@ -596,8 +594,6 @@ def test_badandgoodbackslashreplaceexceptions(self): (r, 2) ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_badandgoodnamereplaceexceptions(self): # "namereplace" complains about a non-exception passed in self.assertRaises( @@ -644,8 +640,6 @@ def test_badandgoodnamereplaceexceptions(self): (r, 1 + len(s)) ) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_badandgoodsurrogateescapeexceptions(self): surrogateescape_errors = codecs.lookup_error('surrogateescape') # "surrogateescape" complains about a non-exception passed in diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 097dae17ba..c444e27a5a 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -401,7 +401,7 @@ pub mod errors { let mut out = String::with_capacity(num_chars * 4); for c in err_str.code_points() { let c_u32 = c.to_u32(); - if let Some(c_name) = unicode_names2::name(c.to_char_lossy()) { + if let Some(c_name) = c.to_char().and_then(unicode_names2::name) { write!(out, "\\N{{{c_name}}}").unwrap(); } else if c_u32 >= 0x10000 { write!(out, "\\U{c_u32:08x}").unwrap(); diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs index f6ae628bad..21c5de28bd 100644 --- a/common/src/wtf8/mod.rs +++ b/common/src/wtf8/mod.rs @@ -574,6 +574,12 @@ impl> FromIterator for Wtf8Buf { } } +impl Hash for Wtf8Buf { + fn hash(&self, state: &mut H) { + Wtf8::hash(self, state) + } +} + impl AsRef for Wtf8Buf { fn as_ref(&self) -> &Wtf8 { self @@ -692,6 +698,13 @@ impl Default for &Wtf8 { } } +impl Hash for Wtf8 { + fn hash(&self, state: &mut H) { + state.write(self.as_bytes()); + state.write_u8(0xff); + } +} + impl Wtf8 { /// Creates a WTF-8 slice from a UTF-8 `&str` slice. /// @@ -722,6 +735,32 @@ impl Wtf8 { unsafe { &mut *(value as *mut [u8] as *mut Wtf8) } } + /// Create a WTF-8 slice from a WTF-8 byte slice. + // + // whooops! using WTF-8 for interchange! + #[inline] + pub fn from_bytes(b: &[u8]) -> Option<&Self> { + let mut rest = b; + while let Err(e) = std::str::from_utf8(rest) { + rest = &rest[e.valid_up_to()..]; + Self::decode_surrogate(rest)?; + rest = &rest[3..]; + } + Some(unsafe { Wtf8::from_bytes_unchecked(b) }) + } + + fn decode_surrogate(b: &[u8]) -> Option { + let [a, b, c, ..] = *b else { return None }; + if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 { + // it's a three-byte code + let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f); + let 0xD800..=0xDFFF = c else { return None }; + Some(CodePoint { value: c }) + } else { + None + } + } + /// Returns the length, in WTF-8 bytes. #[inline] pub fn len(&self) -> usize { @@ -875,6 +914,14 @@ impl Wtf8 { } } + #[inline] + fn final_lead_surrogate(&self) -> Option { + match self.bytes { + [.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)), + _ => None, + } + } + pub fn is_code_point_boundary(&self, index: usize) -> bool { is_code_point_boundary(self, index) } @@ -1481,6 +1528,12 @@ impl From for Box { } } +impl From> for Wtf8Buf { + fn from(w: Box) -> Self { + Wtf8Buf::from_box(w) + } +} + impl From for Box { fn from(s: String) -> Self { s.into_boxed_str().into() diff --git a/compiler/codegen/Cargo.toml b/compiler/codegen/Cargo.toml index 9cf93bc22b..c7ff439f78 100644 --- a/compiler/codegen/Cargo.toml +++ b/compiler/codegen/Cargo.toml @@ -30,6 +30,8 @@ num-complex = { workspace = true } num-traits = { workspace = true } thiserror = { workspace = true } malachite-bigint = { workspace = true } +memchr = { workspace = true } +unicode_names2 = { workspace = true } [dev-dependencies] # rustpython-parser = { workspace = true } diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index a6eb216e2a..a03a1fdb50 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -21,13 +21,14 @@ use ruff_python_ast::{ Alias, Arguments, BoolOp, CmpOp, Comprehension, ConversionFlag, DebugText, Decorator, DictItem, ExceptHandler, ExceptHandlerExceptHandler, Expr, ExprAttribute, ExprBoolOp, ExprFString, ExprList, ExprName, ExprStarred, ExprSubscript, ExprTuple, ExprUnaryOp, FString, - FStringElement, FStringElements, FStringPart, Int, Keyword, MatchCase, ModExpression, - ModModule, Operator, Parameters, Pattern, PatternMatchAs, PatternMatchValue, Stmt, StmtExpr, - TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, TypeParams, UnaryOp, - WithItem, + FStringElement, FStringElements, FStringFlags, FStringPart, Int, Keyword, MatchCase, + ModExpression, ModModule, Operator, Parameters, Pattern, PatternMatchAs, PatternMatchValue, + Stmt, StmtExpr, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, + TypeParams, UnaryOp, WithItem, }; use ruff_source_file::OneIndexed; use ruff_text_size::{Ranged, TextRange}; +use rustpython_common::wtf8::Wtf8Buf; // use rustpython_ast::located::{self as located_ast, Located}; use rustpython_compiler_core::{ Mode, @@ -375,7 +376,9 @@ impl Compiler<'_> { let (doc, statements) = split_doc(&body.body, &self.opts); if let Some(value) = doc { - self.emit_load_const(ConstantData::Str { value }); + self.emit_load_const(ConstantData::Str { + value: value.into(), + }); let doc = self.name("__doc__"); emit!(self, Instruction::StoreGlobal(doc)) } @@ -636,14 +639,12 @@ impl Compiler<'_> { statement.range(), )); } - vec![ConstantData::Str { - value: "*".to_owned(), - }] + vec![ConstantData::Str { value: "*".into() }] } else { names .iter() .map(|n| ConstantData::Str { - value: n.name.to_string(), + value: n.name.as_str().into(), }) .collect() }; @@ -954,7 +955,7 @@ impl Compiler<'_> { self.pop_symbol_table(); } self.emit_load_const(ConstantData::Str { - value: name_string.clone(), + value: name_string.clone().into(), }); emit!(self, Instruction::TypeAlias); self.store_name(&name_string)?; @@ -1028,7 +1029,7 @@ impl Compiler<'_> { let default_kw_count = kw_with_defaults.len(); for (arg, default) in kw_with_defaults.iter() { self.emit_load_const(ConstantData::Str { - value: arg.name.to_string(), + value: arg.name.as_str().into(), }); self.compile_expression(default)?; } @@ -1101,7 +1102,7 @@ impl Compiler<'_> { if let Some(expr) = &bound { self.compile_expression(expr)?; self.emit_load_const(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); emit!(self, Instruction::TypeVarWithBound); emit!(self, Instruction::Duplicate); @@ -1109,7 +1110,7 @@ impl Compiler<'_> { } else { // self.store_name(type_name.as_str())?; self.emit_load_const(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); emit!(self, Instruction::TypeVar); emit!(self, Instruction::Duplicate); @@ -1118,7 +1119,7 @@ impl Compiler<'_> { } TypeParam::ParamSpec(TypeParamParamSpec { name, .. }) => { self.emit_load_const(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); emit!(self, Instruction::ParamSpec); emit!(self, Instruction::Duplicate); @@ -1126,7 +1127,7 @@ impl Compiler<'_> { } TypeParam::TypeVarTuple(TypeParamTypeVarTuple { name, .. }) => { self.emit_load_const(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); emit!(self, Instruction::TypeVarTuple); emit!(self, Instruction::Duplicate); @@ -1363,7 +1364,7 @@ impl Compiler<'_> { if let Some(annotation) = returns { // key: self.emit_load_const(ConstantData::Str { - value: "return".to_owned(), + value: "return".into(), }); // value: self.compile_annotation(annotation)?; @@ -1380,7 +1381,7 @@ impl Compiler<'_> { for param in parameters_iter { if let Some(annotation) = ¶m.annotation { self.emit_load_const(ConstantData::Str { - value: self.mangle(param.name.as_str()).into_owned(), + value: self.mangle(param.name.as_str()).into_owned().into(), }); self.compile_annotation(annotation)?; num_annotations += 1; @@ -1410,7 +1411,7 @@ impl Compiler<'_> { code: Box::new(code), }); self.emit_load_const(ConstantData::Str { - value: qualified_name, + value: qualified_name.into(), }); // Turn code object into function object: @@ -1418,7 +1419,9 @@ impl Compiler<'_> { if let Some(value) = doc_str { emit!(self, Instruction::Duplicate); - self.emit_load_const(ConstantData::Str { value }); + self.emit_load_const(ConstantData::Str { + value: value.into(), + }); emit!(self, Instruction::Rotate2); let doc = self.name("__doc__"); emit!(self, Instruction::StoreAttr { idx: doc }); @@ -1547,7 +1550,7 @@ impl Compiler<'_> { let dunder_module = self.name("__module__"); emit!(self, Instruction::StoreLocal(dunder_module)); self.emit_load_const(ConstantData::Str { - value: qualified_name, + value: qualified_name.into(), }); let qualname = self.name("__qualname__"); emit!(self, Instruction::StoreLocal(qualname)); @@ -1608,16 +1611,12 @@ impl Compiler<'_> { self.emit_load_const(ConstantData::Code { code: Box::new(code), }); - self.emit_load_const(ConstantData::Str { - value: name.to_owned(), - }); + self.emit_load_const(ConstantData::Str { value: name.into() }); // Turn code object into function object: emit!(self, Instruction::MakeFunction(func_flags)); - self.emit_load_const(ConstantData::Str { - value: name.to_owned(), - }); + self.emit_load_const(ConstantData::Str { value: name.into() }); // Call the __build_class__ builtin let call = if let Some(arguments) = arguments { @@ -1638,7 +1637,7 @@ impl Compiler<'_> { // Doc string value: self.emit_load_const(match doc_str { - Some(doc) => ConstantData::Str { value: doc }, + Some(doc) => ConstantData::Str { value: doc.into() }, None => ConstantData::None, // set docstring None if not declared }); } @@ -2031,7 +2030,7 @@ impl Compiler<'_> { let ident = Default::default(); let codegen = ruff_python_codegen::Generator::new(&ident, Default::default()); self.emit_load_const(ConstantData::Str { - value: codegen.expr(annotation), + value: codegen.expr(annotation).into(), }); } else { self.compile_expression(annotation)?; @@ -2063,7 +2062,7 @@ impl Compiler<'_> { let annotations = self.name("__annotations__"); emit!(self, Instruction::LoadNameAny(annotations)); self.emit_load_const(ConstantData::Str { - value: self.mangle(id.as_str()).into_owned(), + value: self.mangle(id.as_str()).into_owned().into(), }); emit!(self, Instruction::StoreSubscript); } else { @@ -2538,7 +2537,7 @@ impl Compiler<'_> { self.emit_load_const(ConstantData::Code { code: Box::new(code), }); - self.emit_load_const(ConstantData::Str { value: name }); + self.emit_load_const(ConstantData::Str { value: name.into() }); // Turn code object into function object: emit!(self, Instruction::MakeFunction(func_flags)); @@ -2679,9 +2678,23 @@ impl Compiler<'_> { self.compile_expr_fstring(fstring)?; } Expr::StringLiteral(string) => { - self.emit_load_const(ConstantData::Str { - value: string.value.to_str().to_owned(), - }); + let value = string.value.to_str(); + if value.contains(char::REPLACEMENT_CHARACTER) { + let value = string + .value + .iter() + .map(|lit| { + let source = self.source_code.get_range(lit.range); + crate::string_parser::parse_string_literal(source, lit.flags.into()) + }) + .collect(); + // might have a surrogate literal; should reparse to be sure + self.emit_load_const(ConstantData::Str { value }); + } else { + self.emit_load_const(ConstantData::Str { + value: value.into(), + }); + } } Expr::BytesLiteral(bytes) => { let iter = bytes.value.iter().flat_map(|x| x.iter().copied()); @@ -2732,7 +2745,7 @@ impl Compiler<'_> { for keyword in sub_keywords { if let Some(name) = &keyword.arg { self.emit_load_const(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); self.compile_expression(&keyword.value)?; sub_size += 1; @@ -2822,7 +2835,7 @@ impl Compiler<'_> { for keyword in &arguments.keywords { if let Some(name) = &keyword.arg { kwarg_names.push(ConstantData::Str { - value: name.to_string(), + value: name.as_str().into(), }); } else { // This means **kwargs! @@ -3058,9 +3071,7 @@ impl Compiler<'_> { }); // List comprehension function name: - self.emit_load_const(ConstantData::Str { - value: name.to_owned(), - }); + self.emit_load_const(ConstantData::Str { value: name.into() }); // Turn code object into function object: emit!(self, Instruction::MakeFunction(func_flags)); @@ -3358,9 +3369,19 @@ impl Compiler<'_> { fn compile_fstring_part(&mut self, part: &FStringPart) -> CompileResult<()> { match part { FStringPart::Literal(string) => { - self.emit_load_const(ConstantData::Str { - value: string.value.to_string(), - }); + if string.value.contains(char::REPLACEMENT_CHARACTER) { + // might have a surrogate literal; should reparse to be sure + let source = self.source_code.get_range(string.range); + let value = + crate::string_parser::parse_string_literal(source, string.flags.into()); + self.emit_load_const(ConstantData::Str { + value: value.into(), + }); + } else { + self.emit_load_const(ConstantData::Str { + value: string.value.to_string().into(), + }); + } Ok(()) } FStringPart::FString(fstring) => self.compile_fstring(fstring), @@ -3368,19 +3389,32 @@ impl Compiler<'_> { } fn compile_fstring(&mut self, fstring: &FString) -> CompileResult<()> { - self.compile_fstring_elements(&fstring.elements) + self.compile_fstring_elements(fstring.flags, &fstring.elements) } fn compile_fstring_elements( &mut self, + flags: FStringFlags, fstring_elements: &FStringElements, ) -> CompileResult<()> { for element in fstring_elements { match element { FStringElement::Literal(string) => { - self.emit_load_const(ConstantData::Str { - value: string.value.to_string(), - }); + if string.value.contains(char::REPLACEMENT_CHARACTER) { + // might have a surrogate literal; should reparse to be sure + let source = self.source_code.get_range(string.range); + let value = crate::string_parser::parse_fstring_literal_element( + source.into(), + flags.into(), + ); + self.emit_load_const(ConstantData::Str { + value: value.into(), + }); + } else { + self.emit_load_const(ConstantData::Str { + value: string.value.to_string().into(), + }); + } } FStringElement::Expression(fstring_expr) => { let mut conversion = fstring_expr.conversion; @@ -3393,11 +3427,13 @@ impl Compiler<'_> { let source = source.to_string(); self.emit_load_const(ConstantData::Str { - value: leading.to_string(), + value: leading.to_string().into(), + }); + self.emit_load_const(ConstantData::Str { + value: source.into(), }); - self.emit_load_const(ConstantData::Str { value: source }); self.emit_load_const(ConstantData::Str { - value: trailing.to_string(), + value: trailing.to_string().into(), }); 3 @@ -3407,7 +3443,7 @@ impl Compiler<'_> { match &fstring_expr.format_spec { None => { self.emit_load_const(ConstantData::Str { - value: String::new(), + value: Wtf8Buf::new(), }); // Match CPython behavior: If debug text is present, apply repr conversion. // See: https://github.com/python/cpython/blob/f61afca262d3a0aa6a8a501db0b1936c60858e35/Parser/action_helpers.c#L1456 @@ -3416,7 +3452,7 @@ impl Compiler<'_> { } } Some(format_spec) => { - self.compile_fstring_elements(&format_spec.elements)?; + self.compile_fstring_elements(flags, &format_spec.elements)?; } } @@ -3449,7 +3485,7 @@ impl Compiler<'_> { if element_count == 0 { // ensure to put an empty string on the stack if there aren't any fstring elements self.emit_load_const(ConstantData::Str { - value: String::new(), + value: Wtf8Buf::new(), }); } else if element_count > 1 { emit!( diff --git a/compiler/codegen/src/lib.rs b/compiler/codegen/src/lib.rs index ceadb3c364..d44844543e 100644 --- a/compiler/codegen/src/lib.rs +++ b/compiler/codegen/src/lib.rs @@ -11,6 +11,7 @@ type IndexSet = indexmap::IndexSet; pub mod compile; pub mod error; pub mod ir; +mod string_parser; pub mod symboltable; pub use compile::CompileOpts; diff --git a/compiler/codegen/src/string_parser.rs b/compiler/codegen/src/string_parser.rs new file mode 100644 index 0000000000..7bdb86aa5e --- /dev/null +++ b/compiler/codegen/src/string_parser.rs @@ -0,0 +1,287 @@ +//! A stripped-down version of ruff's string literal parser, modified to +//! handle surrogates in string literals and output WTF-8. +//! +//! Any `unreachable!()` statements in this file are because we only get here +//! after ruff has already successfully parsed the string literal, meaning +//! we don't need to do any validation or error handling. + +use std::convert::Infallible; + +use ruff_python_ast::{AnyStringFlags, StringFlags}; +use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; + +// use ruff_python_parser::{LexicalError, LexicalErrorType}; +type LexicalError = Infallible; + +enum EscapedChar { + Literal(CodePoint), + Escape(char), +} + +struct StringParser { + /// The raw content of the string e.g., the `foo` part in `"foo"`. + source: Box, + /// Current position of the parser in the source. + cursor: usize, + /// Flags that can be used to query information about the string. + flags: AnyStringFlags, +} + +impl StringParser { + fn new(source: Box, flags: AnyStringFlags) -> Self { + Self { + source, + cursor: 0, + flags, + } + } + + #[inline] + fn skip_bytes(&mut self, bytes: usize) -> &str { + let skipped_str = &self.source[self.cursor..self.cursor + bytes]; + self.cursor += bytes; + skipped_str + } + + /// Returns the next byte in the string, if there is one. + /// + /// # Panics + /// + /// When the next byte is a part of a multi-byte character. + #[inline] + fn next_byte(&mut self) -> Option { + self.source[self.cursor..].as_bytes().first().map(|&byte| { + self.cursor += 1; + byte + }) + } + + #[inline] + fn next_char(&mut self) -> Option { + self.source[self.cursor..].chars().next().inspect(|c| { + self.cursor += c.len_utf8(); + }) + } + + #[inline] + fn peek_byte(&self) -> Option { + self.source[self.cursor..].as_bytes().first().copied() + } + + fn parse_unicode_literal(&mut self, literal_number: usize) -> Result { + let mut p: u32 = 0u32; + for i in 1..=literal_number { + match self.next_char() { + Some(c) => match c.to_digit(16) { + Some(d) => p += d << ((literal_number - i) * 4), + None => unreachable!(), + }, + None => unreachable!(), + } + } + Ok(CodePoint::from_u32(p).unwrap()) + } + + fn parse_octet(&mut self, o: u8) -> char { + let mut radix_bytes = [o, 0, 0]; + let mut len = 1; + + while len < 3 { + let Some(b'0'..=b'7') = self.peek_byte() else { + break; + }; + + radix_bytes[len] = self.next_byte().unwrap(); + len += 1; + } + + // OK because radix_bytes is always going to be in the ASCII range. + let radix_str = std::str::from_utf8(&radix_bytes[..len]).expect("ASCII bytes"); + let value = u32::from_str_radix(radix_str, 8).unwrap(); + char::from_u32(value).unwrap() + } + + fn parse_unicode_name(&mut self) -> Result { + let Some('{') = self.next_char() else { + unreachable!() + }; + + let Some(close_idx) = self.source[self.cursor..].find('}') else { + unreachable!() + }; + + let name_and_ending = self.skip_bytes(close_idx + 1); + let name = &name_and_ending[..name_and_ending.len() - 1]; + + unicode_names2::character(name).ok_or_else(|| unreachable!()) + } + + /// Parse an escaped character, returning the new character. + fn parse_escaped_char(&mut self) -> Result, LexicalError> { + let Some(first_char) = self.next_char() else { + unreachable!() + }; + + let new_char = match first_char { + '\\' => '\\'.into(), + '\'' => '\''.into(), + '\"' => '"'.into(), + 'a' => '\x07'.into(), + 'b' => '\x08'.into(), + 'f' => '\x0c'.into(), + 'n' => '\n'.into(), + 'r' => '\r'.into(), + 't' => '\t'.into(), + 'v' => '\x0b'.into(), + o @ '0'..='7' => self.parse_octet(o as u8).into(), + 'x' => self.parse_unicode_literal(2)?, + 'u' if !self.flags.is_byte_string() => self.parse_unicode_literal(4)?, + 'U' if !self.flags.is_byte_string() => self.parse_unicode_literal(8)?, + 'N' if !self.flags.is_byte_string() => self.parse_unicode_name()?.into(), + // Special cases where the escape sequence is not a single character + '\n' => return Ok(None), + '\r' => { + if self.peek_byte() == Some(b'\n') { + self.next_byte(); + } + + return Ok(None); + } + _ => return Ok(Some(EscapedChar::Escape(first_char))), + }; + + Ok(Some(EscapedChar::Literal(new_char))) + } + + fn parse_fstring_middle(mut self) -> Result, LexicalError> { + // Fast-path: if the f-string doesn't contain any escape sequences, return the literal. + let Some(mut index) = memchr::memchr3(b'{', b'}', b'\\', self.source.as_bytes()) else { + return Ok(self.source.into()); + }; + + let mut value = Wtf8Buf::with_capacity(self.source.len()); + loop { + // Add the characters before the escape sequence (or curly brace) to the string. + let before_with_slash_or_brace = self.skip_bytes(index + 1); + let before = &before_with_slash_or_brace[..before_with_slash_or_brace.len() - 1]; + value.push_str(before); + + // Add the escaped character to the string. + match &self.source.as_bytes()[self.cursor - 1] { + // If there are any curly braces inside a `FStringMiddle` token, + // then they were escaped (i.e. `{{` or `}}`). This means that + // we need increase the location by 2 instead of 1. + b'{' => value.push_char('{'), + b'}' => value.push_char('}'), + // We can encounter a `\` as the last character in a `FStringMiddle` + // token which is valid in this context. For example, + // + // ```python + // f"\{foo} \{bar:\}" + // # ^ ^^ ^ + // ``` + // + // Here, the `FStringMiddle` token content will be "\" and " \" + // which is invalid if we look at the content in isolation: + // + // ```python + // "\" + // ``` + // + // However, the content is syntactically valid in the context of + // the f-string because it's a substring of the entire f-string. + // This is still an invalid escape sequence, but we don't want to + // raise a syntax error as is done by the CPython parser. It might + // be supported in the future, refer to point 3: https://peps.python.org/pep-0701/#rejected-ideas + b'\\' => { + if !self.flags.is_raw_string() && self.peek_byte().is_some() { + match self.parse_escaped_char()? { + None => {} + Some(EscapedChar::Literal(c)) => value.push(c), + Some(EscapedChar::Escape(c)) => { + value.push_char('\\'); + value.push_char(c); + } + } + } else { + value.push_char('\\'); + } + } + ch => { + unreachable!("Expected '{{', '}}', or '\\' but got {:?}", ch); + } + } + + let Some(next_index) = + memchr::memchr3(b'{', b'}', b'\\', self.source[self.cursor..].as_bytes()) + else { + // Add the rest of the string to the value. + let rest = &self.source[self.cursor..]; + value.push_str(rest); + break; + }; + + index = next_index; + } + + Ok(value.into()) + } + + fn parse_string(mut self) -> Result, LexicalError> { + if self.flags.is_raw_string() { + // For raw strings, no escaping is necessary. + return Ok(self.source.into()); + } + + let Some(mut escape) = memchr::memchr(b'\\', self.source.as_bytes()) else { + // If the string doesn't contain any escape sequences, return the owned string. + return Ok(self.source.into()); + }; + + // If the string contains escape sequences, we need to parse them. + let mut value = Wtf8Buf::with_capacity(self.source.len()); + + loop { + // Add the characters before the escape sequence to the string. + let before_with_slash = self.skip_bytes(escape + 1); + let before = &before_with_slash[..before_with_slash.len() - 1]; + value.push_str(before); + + // Add the escaped character to the string. + match self.parse_escaped_char()? { + None => {} + Some(EscapedChar::Literal(c)) => value.push(c), + Some(EscapedChar::Escape(c)) => { + value.push_char('\\'); + value.push_char(c); + } + } + + let Some(next_escape) = self.source[self.cursor..].find('\\') else { + // Add the rest of the string to the value. + let rest = &self.source[self.cursor..]; + value.push_str(rest); + break; + }; + + // Update the position of the next escape sequence. + escape = next_escape; + } + + Ok(value.into()) + } +} + +pub(crate) fn parse_string_literal(source: &str, flags: AnyStringFlags) -> Box { + let source = &source[flags.opener_len().to_usize()..]; + let source = &source[..source.len() - flags.quote_len().to_usize()]; + StringParser::new(source.into(), flags) + .parse_string() + .unwrap_or_else(|x| match x {}) +} + +pub(crate) fn parse_fstring_literal_element(source: Box, flags: AnyStringFlags) -> Box { + StringParser::new(source, flags) + .parse_fstring_middle() + .unwrap_or_else(|x| match x {}) +} diff --git a/compiler/core/Cargo.toml b/compiler/core/Cargo.toml index 7621c643d5..8ff0cd020d 100644 --- a/compiler/core/Cargo.toml +++ b/compiler/core/Cargo.toml @@ -13,6 +13,7 @@ license.workspace = true ruff_python_ast = { workspace = true } ruff_python_parser = { workspace = true } ruff_source_file = { workspace = true } +rustpython-common = { workspace = true } bitflags = { workspace = true } itertools = { workspace = true } diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index 4cb80020e7..7b018d1df1 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -8,6 +8,7 @@ use num_complex::Complex64; pub use ruff_python_ast::ConversionFlag; // use rustpython_parser_core::source_code::{OneIndexed, SourceLocation}; use ruff_source_file::{OneIndexed, SourceLocation}; +use rustpython_common::wtf8::{Wtf8, Wtf8Buf}; use std::marker::PhantomData; use std::{collections::BTreeSet, fmt, hash, mem}; @@ -678,7 +679,7 @@ pub enum ConstantData { Float { value: f64 }, Complex { value: Complex64 }, Boolean { value: bool }, - Str { value: String }, + Str { value: Wtf8Buf }, Bytes { value: Vec }, Code { code: Box }, None, @@ -738,7 +739,7 @@ pub enum BorrowedConstant<'a, C: Constant> { Float { value: f64 }, Complex { value: Complex64 }, Boolean { value: bool }, - Str { value: &'a str }, + Str { value: &'a Wtf8 }, Bytes { value: &'a [u8] }, Code { code: &'a CodeObject }, Tuple { elements: &'a [C] }, diff --git a/compiler/core/src/marshal.rs b/compiler/core/src/marshal.rs index 1e47a6cac5..0c8da17ff9 100644 --- a/compiler/core/src/marshal.rs +++ b/compiler/core/src/marshal.rs @@ -2,6 +2,7 @@ use crate::bytecode::*; use malachite_bigint::{BigInt, Sign}; use num_complex::Complex64; use ruff_source_file::{OneIndexed, SourceLocation}; +use rustpython_common::wtf8::Wtf8; use std::convert::Infallible; pub const FORMAT_VERSION: u32 = 4; @@ -117,6 +118,9 @@ pub trait Read { fn read_str(&mut self, len: u32) -> Result<&str> { Ok(std::str::from_utf8(self.read_slice(len)?)?) } + fn read_wtf8(&mut self, len: u32) -> Result<&Wtf8> { + Wtf8::from_bytes(self.read_slice(len)?).ok_or(MarshalError::InvalidUtf8) + } fn read_u8(&mut self) -> Result { Ok(u8::from_le_bytes(*self.read_array()?)) } @@ -262,7 +266,7 @@ pub trait MarshalBag: Copy { fn make_ellipsis(&self) -> Self::Value; fn make_float(&self, value: f64) -> Self::Value; fn make_complex(&self, value: Complex64) -> Self::Value; - fn make_str(&self, value: &str) -> Self::Value; + fn make_str(&self, value: &Wtf8) -> Self::Value; fn make_bytes(&self, value: &[u8]) -> Self::Value; fn make_int(&self, value: BigInt) -> Self::Value; fn make_tuple(&self, elements: impl Iterator) -> Self::Value; @@ -299,7 +303,7 @@ impl MarshalBag for Bag { fn make_complex(&self, value: Complex64) -> Self::Value { self.make_constant::(BorrowedConstant::Complex { value }) } - fn make_str(&self, value: &str) -> Self::Value { + fn make_str(&self, value: &Wtf8) -> Self::Value { self.make_constant::(BorrowedConstant::Str { value }) } fn make_bytes(&self, value: &[u8]) -> Self::Value { @@ -368,7 +372,7 @@ pub fn deserialize_value(rdr: &mut R, bag: Bag) -> Res } Type::Ascii | Type::Unicode => { let len = rdr.read_u32()?; - let value = rdr.read_str(len)?; + let value = rdr.read_wtf8(len)?; bag.make_str(value) } Type::Tuple => { @@ -422,7 +426,7 @@ pub enum DumpableValue<'a, D: Dumpable> { Float(f64), Complex(Complex64), Boolean(bool), - Str(&'a str), + Str(&'a Wtf8), Bytes(&'a [u8]), Code(&'a CodeObject), Tuple(&'a [D]), diff --git a/jit/tests/common.rs b/jit/tests/common.rs index a2d4fc3bc1..a4ac8a7967 100644 --- a/jit/tests/common.rs +++ b/jit/tests/common.rs @@ -53,7 +53,9 @@ enum StackValue { impl From for StackValue { fn from(value: ConstantData) -> Self { match value { - ConstantData::Str { value } => StackValue::String(value), + ConstantData::Str { value } => { + StackValue::String(value.into_string().expect("surrogate in test code")) + } ConstantData::None => StackValue::None, ConstantData::Code { code } => StackValue::Code(code), c => unimplemented!("constant {:?} isn't yet supported in py_function!", c), diff --git a/vm/src/builtins/code.rs b/vm/src/builtins/code.rs index ba2d2dd5c3..4bb209f6db 100644 --- a/vm/src/builtins/code.rs +++ b/vm/src/builtins/code.rs @@ -74,7 +74,7 @@ fn borrow_obj_constant(obj: &PyObject) -> BorrowedConstant<'_, Literal> { ref c @ super::complex::PyComplex => BorrowedConstant::Complex { value: c.to_complex() }, - ref s @ super::pystr::PyStr => BorrowedConstant::Str { value: s.as_str() }, + ref s @ super::pystr::PyStr => BorrowedConstant::Str { value: s.as_wtf8() }, ref b @ super::bytes::PyBytes => BorrowedConstant::Bytes { value: b.as_bytes() }, diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 55cefae4f7..8fe3904945 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -1815,6 +1815,18 @@ impl AsRef for PyExact { } } +impl AsRef for PyRefExact { + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() + } +} + +impl AsRef for PyExact { + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() + } +} + impl AnyStrWrapper for PyStrRef { fn as_ref(&self) -> Option<&Wtf8> { Some(self.as_wtf8()) diff --git a/vm/src/intern.rs b/vm/src/intern.rs index 10aaa53454..bb9220d069 100644 --- a/vm/src/intern.rs +++ b/vm/src/intern.rs @@ -1,3 +1,5 @@ +use rustpython_common::wtf8::{Wtf8, Wtf8Buf}; + use crate::{ AsObject, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, VirtualMachine, builtins::{PyStr, PyStrInterned, PyTypeRef}, @@ -86,29 +88,29 @@ pub struct CachedPyStrRef { impl std::hash::Hash for CachedPyStrRef { fn hash(&self, state: &mut H) { - self.inner.as_str().hash(state) + self.inner.as_wtf8().hash(state) } } impl PartialEq for CachedPyStrRef { fn eq(&self, other: &Self) -> bool { - self.inner.as_str() == other.inner.as_str() + self.inner.as_wtf8() == other.inner.as_wtf8() } } impl Eq for CachedPyStrRef {} -impl std::borrow::Borrow for CachedPyStrRef { +impl std::borrow::Borrow for CachedPyStrRef { #[inline] - fn borrow(&self) -> &str { - self.inner.as_str() + fn borrow(&self) -> &Wtf8 { + self.as_wtf8() } } -impl AsRef for CachedPyStrRef { +impl AsRef for CachedPyStrRef { #[inline] - fn as_ref(&self) -> &str { - self.as_str() + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() } } @@ -121,8 +123,8 @@ impl CachedPyStrRef { } #[inline] - fn as_str(&self) -> &str { - self.inner.as_str() + fn as_wtf8(&self) -> &Wtf8 { + self.inner.as_wtf8() } } @@ -209,6 +211,8 @@ impl ToPyObject for &'static PyInterned { } mod sealed { + use rustpython_common::wtf8::{Wtf8, Wtf8Buf}; + use crate::{ builtins::PyStr, object::{Py, PyExact, PyRefExact}, @@ -218,11 +222,14 @@ mod sealed { impl SealedInternable for String {} impl SealedInternable for &str {} + impl SealedInternable for Wtf8Buf {} + impl SealedInternable for &Wtf8 {} impl SealedInternable for PyRefExact {} pub trait SealedMaybeInterned {} impl SealedMaybeInterned for str {} + impl SealedMaybeInterned for Wtf8 {} impl SealedMaybeInterned for PyExact {} impl SealedMaybeInterned for Py {} } @@ -250,6 +257,21 @@ impl InternableString for &str { } } +impl InternableString for Wtf8Buf { + type Interned = Wtf8; + fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact { + let obj = PyRef::new_ref(PyStr::from(self), str_type, None); + unsafe { PyRefExact::new_unchecked(obj) } + } +} + +impl InternableString for &Wtf8 { + type Interned = Wtf8; + fn into_pyref_exact(self, str_type: PyTypeRef) -> PyRefExact { + self.to_owned().into_pyref_exact(str_type) + } +} + impl InternableString for PyRefExact { type Interned = Py; #[inline] @@ -259,7 +281,7 @@ impl InternableString for PyRefExact { } pub trait MaybeInternedString: - AsRef + crate::dictdatatype::DictKey + sealed::SealedMaybeInterned + AsRef + crate::dictdatatype::DictKey + sealed::SealedMaybeInterned { fn as_interned(&self) -> Option<&'static PyStrInterned>; } @@ -271,6 +293,13 @@ impl MaybeInternedString for str { } } +impl MaybeInternedString for Wtf8 { + #[inline(always)] + fn as_interned(&self) -> Option<&'static PyStrInterned> { + None + } +} + impl MaybeInternedString for PyExact { #[inline(always)] fn as_interned(&self) -> Option<&'static PyStrInterned> { @@ -296,7 +325,7 @@ impl PyObject { if self.is_interned() { s.unwrap().as_interned() } else if let Some(s) = s { - vm.ctx.interned_str(s.as_str()) + vm.ctx.interned_str(s.as_wtf8()) } else { None } diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index fd7332e7c2..564ee5bf6c 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -10,6 +10,7 @@ mod decl { PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat, PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple, }, + common::wtf8::Wtf8, convert::ToPyObject, function::{ArgBytesLike, OptionalArg}, object::AsObject, @@ -53,7 +54,7 @@ mod decl { f(Complex(pycomplex.to_complex64())) } ref pystr @ PyStr => { - f(Str(pystr.as_str())) + f(Str(pystr.as_wtf8())) } ref pylist @ PyList => { f(List(&pylist.borrow_vec())) @@ -139,7 +140,7 @@ mod decl { fn make_complex(&self, value: Complex64) -> Self::Value { self.0.ctx.new_complex(value).into() } - fn make_str(&self, value: &str) -> Self::Value { + fn make_str(&self, value: &Wtf8) -> Self::Value { self.0.ctx.new_str(value).into() } fn make_bytes(&self, value: &[u8]) -> Self::Value { From 0a07cd931f5437f3f24201ee2dec6e4e431e52f6 Mon Sep 17 00:00:00 2001 From: Noa Date: Wed, 26 Mar 2025 20:37:26 -0500 Subject: [PATCH 2/3] Fix more surrogate crashes --- Lib/test/test_codecs.py | 31 ++++++------ Lib/test/test_json/test_scanstring.py | 2 - Lib/test/test_regrtest.py | 2 - Lib/test/test_stringprep.py | 2 - Lib/test/test_subprocess.py | 2 - Lib/test/test_tarfile.py | 14 ----- Lib/test/test_unicode.py | 8 --- Lib/test/test_userstring.py | 4 -- Lib/test/test_zipimport.py | 1 + common/src/wtf8/mod.rs | 55 ++++++++++++-------- stdlib/src/json.rs | 5 +- stdlib/src/json/machinery.rs | 73 +++++++++++++-------------- vm/src/builtins/complex.rs | 9 ++-- vm/src/builtins/float.rs | 2 +- vm/src/builtins/int.rs | 2 +- vm/src/builtins/str.rs | 32 +++++++++--- vm/src/builtins/type.rs | 2 +- vm/src/protocol/number.rs | 2 +- vm/src/protocol/object.rs | 2 +- vm/src/stdlib/codecs.rs | 2 +- vm/src/stdlib/io.rs | 16 +++--- vm/src/stdlib/time.rs | 8 ++- vm/src/utils.rs | 6 ++- 23 files changed, 142 insertions(+), 140 deletions(-) diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index df04653c66..a12e5893dc 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -869,6 +869,11 @@ def test_bug691291(self): with reader: self.assertEqual(reader.read(), s1) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF16LETest(ReadTest, unittest.TestCase): encoding = "utf-16-le" ill_formed_sequence = b"\x80\xdc" @@ -917,6 +922,11 @@ def test_nonbmp(self): self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding), "\U00010203") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF16BETest(ReadTest, unittest.TestCase): encoding = "utf-16-be" ill_formed_sequence = b"\xdc\x80" @@ -965,6 +975,11 @@ def test_nonbmp(self): self.assertEqual(b'\xd8\x00\xde\x03'.decode(self.encoding), "\U00010203") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF8Test(ReadTest, unittest.TestCase): encoding = "utf-8" ill_formed_sequence = b"\xed\xb2\x80" @@ -998,8 +1013,6 @@ def test_decoder_state(self): self.check_state_handling_decode(self.encoding, u, u.encode(self.encoding)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode_error(self): for data, error_handler, expected in ( (b'[\x80\xff]', 'ignore', '[]'), @@ -1026,8 +1039,6 @@ def test_lone_surrogates(self): exc = cm.exception self.assertEqual(exc.object[exc.start:exc.end], '\uD800\uDFFF') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogatepass_handler(self): self.assertEqual("abc\ud800def".encode(self.encoding, "surrogatepass"), self.BOM + b"abc\xed\xa0\x80def") @@ -2884,8 +2895,6 @@ def test_escape_encode(self): class SurrogateEscapeTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_utf8(self): # Bad byte self.assertEqual(b"foo\x80bar".decode("utf-8", "surrogateescape"), @@ -2898,8 +2907,6 @@ def test_utf8(self): self.assertEqual("\udced\udcb0\udc80".encode("utf-8", "surrogateescape"), b"\xed\xb0\x80") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ascii(self): # bad byte self.assertEqual(b"foo\x80bar".decode("ascii", "surrogateescape"), @@ -2916,8 +2923,6 @@ def test_charmap(self): self.assertEqual("foo\udca5bar".encode("iso-8859-3", "surrogateescape"), b"foo\xa5bar") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_latin1(self): # Issue6373 self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin-1", "surrogateescape"), @@ -3561,8 +3566,6 @@ class ASCIITest(unittest.TestCase): def test_encode(self): self.assertEqual('abc123'.encode('ascii'), b'abc123') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_error(self): for data, error_handler, expected in ( ('[\x80\xff\u20ac]', 'ignore', b'[]'), @@ -3585,8 +3588,6 @@ def test_encode_surrogateescape_error(self): def test_decode(self): self.assertEqual(b'abc'.decode('ascii'), 'abc') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode_error(self): for data, error_handler, expected in ( (b'[\x80\xff]', 'ignore', '[]'), @@ -3609,8 +3610,6 @@ def test_encode(self): with self.subTest(data=data, expected=expected): self.assertEqual(data.encode('latin1'), expected) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_errors(self): for data, error_handler, expected in ( ('[\u20ac\udc80]', 'ignore', b'[]'), diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py index 682dc74999..af4bb3a639 100644 --- a/Lib/test/test_json/test_scanstring.py +++ b/Lib/test/test_json/test_scanstring.py @@ -86,8 +86,6 @@ def test_scanstring(self): scanstring('["Bad value", truth]', 2, True), ('Bad value', 12)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): scanstring = self.json.decoder.scanstring def assertScan(given, expect): diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py index aab2d9f7ae..fc56ec4afc 100644 --- a/Lib/test/test_regrtest.py +++ b/Lib/test/test_regrtest.py @@ -945,7 +945,6 @@ def test_leak(self): """) self.check_leak(code, 'file descriptors') - @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows') def test_list_tests(self): # test --list-tests tests = [self.create_test() for i in range(5)] @@ -953,7 +952,6 @@ def test_list_tests(self): self.assertEqual(output.rstrip().splitlines(), tests) - @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows') def test_list_cases(self): # test --list-cases code = textwrap.dedent(""" diff --git a/Lib/test/test_stringprep.py b/Lib/test/test_stringprep.py index 118f3f0867..d4b4a13d0d 100644 --- a/Lib/test/test_stringprep.py +++ b/Lib/test/test_stringprep.py @@ -6,8 +6,6 @@ from stringprep import * class StringprepTests(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test(self): self.assertTrue(in_table_a1("\u0221")) self.assertFalse(in_table_a1("\u0222")) diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index d7507eb7f0..e5b18fe20f 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -1198,8 +1198,6 @@ def test_universal_newlines_communicate_encodings(self): stdout, stderr = popen.communicate(input='') self.assertEqual(stdout, '1\n2\n3\n4') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_communicate_errors(self): for errors, expected in [ ('ignore', ''), diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 4ae81cb99f..63f7b347ad 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase): format = tarfile.USTAR_FORMAT - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_uname_unicode(self): - super().test_uname_unicode() - # Test whether the utf-8 encoded version of a filename exceeds the 100 # bytes name field limit (every occurrence of '\xff' will be expanded to 2 # bytes). @@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase): format = tarfile.GNU_FORMAT - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_uname_unicode(self): - super().test_uname_unicode() - - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bad_pax_header(self): # Test for issue #8633. GNU tar <= 1.23 creates raw binary fields # without a hdrcharset=BINARY header. @@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase): # PAX_FORMAT ignores encoding in write mode. test_unicode_filename_error = None - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_binary_header(self): # Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field. for encoding, name in ( diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 5c2c6c29b1..4da63c54d4 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -608,8 +608,6 @@ def test_bytes_comparison(self): self.assertEqual('abc' == bytearray(b'abc'), False) self.assertEqual('abc' != bytearray(b'abc'), True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_comparison(self): # Comparisons: self.assertEqual('abc', 'abc') @@ -830,8 +828,6 @@ def test_isidentifier_legacy(self): warnings.simplefilter('ignore', DeprecationWarning) self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier()) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_isprintable(self): self.assertTrue("".isprintable()) self.assertTrue(" ".isprintable()) @@ -847,8 +843,6 @@ def test_isprintable(self): self.assertTrue('\U0001F46F'.isprintable()) self.assertFalse('\U000E0020'.isprintable()) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800', 'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'): @@ -1827,8 +1821,6 @@ def test_codecs_utf7(self): 'ill-formed sequence'): b'+@'.decode('utf-7') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_codecs_utf8(self): self.assertEqual(''.encode('utf-8'), b'') self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac') diff --git a/Lib/test/test_userstring.py b/Lib/test/test_userstring.py index c0017794e8..51b4f6041e 100644 --- a/Lib/test/test_userstring.py +++ b/Lib/test/test_userstring.py @@ -53,8 +53,6 @@ def __rmod__(self, other): str3 = ustr3('TEST') self.assertEqual(fmt2 % str3, 'value is TEST') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_default_args(self): self.checkequal(b'hello', 'hello', 'encode') # Check that encoding defaults to utf-8 @@ -62,8 +60,6 @@ def test_encode_default_args(self): # Check that errors defaults to 'strict' self.checkraises(UnicodeError, '\ud800', 'encode') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_explicit_none_args(self): self.checkequal(b'hello', 'hello', 'encode', None, None) # Check that encoding defaults to utf-8 diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index b291d53016..488a67e80f 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -730,6 +730,7 @@ def testTraceback(self): @unittest.skipIf(os_helper.TESTFN_UNENCODABLE is None, "need an unencodable filename") + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def testUnencodable(self): filename = os_helper.TESTFN_UNENCODABLE + ".zip" self.addCleanup(os_helper.unlink, filename) diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs index 21c5de28bd..57a1d6d7de 100644 --- a/common/src/wtf8/mod.rs +++ b/common/src/wtf8/mod.rs @@ -122,18 +122,18 @@ impl CodePoint { /// Returns the numeric value of the code point if it is a leading surrogate. #[inline] - pub fn to_lead_surrogate(self) -> Option { + pub fn to_lead_surrogate(self) -> Option { match self.value { - lead @ 0xD800..=0xDBFF => Some(lead as u16), + lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)), _ => None, } } /// Returns the numeric value of the code point if it is a trailing surrogate. #[inline] - pub fn to_trail_surrogate(self) -> Option { + pub fn to_trail_surrogate(self) -> Option { match self.value { - trail @ 0xDC00..=0xDFFF => Some(trail as u16), + trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)), _ => None, } } @@ -216,6 +216,18 @@ impl PartialEq for char { } } +#[derive(Clone, Copy)] +pub struct LeadSurrogate(u16); + +#[derive(Clone, Copy)] +pub struct TrailSurrogate(u16); + +impl LeadSurrogate { + pub fn merge(self, trail: TrailSurrogate) -> char { + decode_surrogate_pair(self.0, trail.0) + } +} + /// An owned, growable string of well-formed WTF-8 data. /// /// Similar to `String`, but can additionally contain surrogate code points @@ -291,6 +303,14 @@ impl Wtf8Buf { Wtf8Buf { bytes: value } } + /// Create a WTF-8 string from a WTF-8 byte vec. + pub fn from_bytes(value: Vec) -> Result> { + match Wtf8::from_bytes(&value) { + Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }), + None => Err(value), + } + } + /// Creates a WTF-8 string from a UTF-8 `String`. /// /// This takes ownership of the `String` and does not copy. @@ -750,15 +770,10 @@ impl Wtf8 { } fn decode_surrogate(b: &[u8]) -> Option { - let [a, b, c, ..] = *b else { return None }; - if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 { - // it's a three-byte code - let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f); - let 0xD800..=0xDFFF = c else { return None }; - Some(CodePoint { value: c }) - } else { - None - } + let [0xed, b2 @ (0xa0..), b3, ..] = *b else { + return None; + }; + Some(decode_surrogate(b2, b3).into()) } /// Returns the length, in WTF-8 bytes. @@ -914,14 +929,6 @@ impl Wtf8 { } } - #[inline] - fn final_lead_surrogate(&self) -> Option { - match self.bytes { - [.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)), - _ => None, - } - } - pub fn is_code_point_boundary(&self, index: usize) -> bool { is_code_point_boundary(self, index) } @@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 { 0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F } +#[inline] +fn decode_surrogate_pair(lead: u16, trail: u16) -> char { + let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32); + unsafe { char::from_u32_unchecked(code_point) } +} + /// Copied from str::is_char_boundary #[inline] fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { diff --git a/stdlib/src/json.rs b/stdlib/src/json.rs index aaac0b8bef..f970ef5dc2 100644 --- a/stdlib/src/json.rs +++ b/stdlib/src/json.rs @@ -13,6 +13,7 @@ mod _json { types::{Callable, Constructor}, }; use malachite_bigint::BigInt; + use rustpython_common::wtf8::Wtf8Buf; use std::str::FromStr; #[pyattr(name = "make_scanner")] @@ -253,8 +254,8 @@ mod _json { end: usize, strict: OptionalArg, vm: &VirtualMachine, - ) -> PyResult<(String, usize)> { - machinery::scanstring(s.as_str(), end, strict.unwrap_or(true)) + ) -> PyResult<(Wtf8Buf, usize)> { + machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true)) .map_err(|e| py_decode_error(e, s, vm)) } } diff --git a/stdlib/src/json/machinery.rs b/stdlib/src/json/machinery.rs index 0614314f4f..4612b5263d 100644 --- a/stdlib/src/json/machinery.rs +++ b/stdlib/src/json/machinery.rs @@ -28,6 +28,9 @@ use std::io; +use itertools::Itertools; +use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; + static ESCAPE_CHARS: [&str; 0x20] = [ "\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b", "\\t", "\\n", "\\u000", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012", @@ -111,22 +114,22 @@ impl DecodeError { } enum StrOrChar<'a> { - Str(&'a str), - Char(char), + Str(&'a Wtf8), + Char(CodePoint), } impl StrOrChar<'_> { fn len(&self) -> usize { match self { StrOrChar::Str(s) => s.len(), - StrOrChar::Char(c) => c.len_utf8(), + StrOrChar::Char(c) => c.len_wtf8(), } } } pub fn scanstring<'a>( - s: &'a str, + s: &'a Wtf8, end: usize, strict: bool, -) -> Result<(String, usize), DecodeError> { +) -> Result<(Wtf8Buf, usize), DecodeError> { let mut chunks: Vec> = Vec::new(); let mut output_len = 0usize; let mut push_chunk = |chunk: StrOrChar<'a>| { @@ -134,16 +137,16 @@ pub fn scanstring<'a>( chunks.push(chunk); }; let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1); - let mut chars = s.char_indices().enumerate().skip(end).peekable(); + let mut chars = s.code_point_indices().enumerate().skip(end).peekable(); let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?; while let Some((char_i, (byte_i, c))) = chars.next() { - match c { + match c.to_char_lossy() { '"' => { push_chunk(StrOrChar::Str(&s[chunk_start..byte_i])); - let mut out = String::with_capacity(output_len); + let mut out = Wtf8Buf::with_capacity(output_len); for x in chunks { match x { - StrOrChar::Str(s) => out.push_str(s), + StrOrChar::Str(s) => out.push_wtf8(s), StrOrChar::Char(c) => out.push(c), } } @@ -152,7 +155,7 @@ pub fn scanstring<'a>( '\\' => { push_chunk(StrOrChar::Str(&s[chunk_start..byte_i])); let (_, (_, c)) = chars.next().ok_or_else(unterminated_err)?; - let esc = match c { + let esc = match c.to_char_lossy() { '"' => "\"", '\\' => "\\", '/' => "/", @@ -162,41 +165,33 @@ pub fn scanstring<'a>( 'r' => "\r", 't' => "\t", 'u' => { - let surrogate_err = || DecodeError::new("unpaired surrogate", char_i); let mut uni = decode_unicode(&mut chars, char_i)?; chunk_start = byte_i + 6; - if (0xd800..=0xdbff).contains(&uni) { + if let Some(lead) = uni.to_lead_surrogate() { // uni is a surrogate -- try to find its pair - if let Some(&(pos2, (_, '\\'))) = chars.peek() { - // ok, the next char starts an escape - chars.next(); - if let Some((_, (_, 'u'))) = chars.peek() { - // ok, it's a unicode escape - chars.next(); - let uni2 = decode_unicode(&mut chars, pos2)?; + let mut chars2 = chars.clone(); + if let Some(((pos2, _), (_, _))) = chars2 + .next_tuple() + .filter(|((_, (_, c1)), (_, (_, c2)))| *c1 == '\\' && *c2 == 'u') + { + let uni2 = decode_unicode(&mut chars2, pos2)?; + if let Some(trail) = uni2.to_trail_surrogate() { + // ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates + uni = lead.merge(trail).into(); chunk_start = pos2 + 6; - if (0xdc00..=0xdfff).contains(&uni2) { - // ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates - uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00)); - } else { - // if we don't find a matching surrogate, error -- until str - // isn't utf8 internally, we can't parse surrogates - return Err(surrogate_err()); - } - } else { - return Err(surrogate_err()); + chars = chars2; } } } - push_chunk(StrOrChar::Char( - std::char::from_u32(uni).ok_or_else(surrogate_err)?, - )); + push_chunk(StrOrChar::Char(uni)); continue; } - _ => return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)), + _ => { + return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)); + } }; chunk_start = byte_i + 2; - push_chunk(StrOrChar::Str(esc)); + push_chunk(StrOrChar::Str(esc.as_ref())); } '\x00'..='\x1f' if strict => { return Err(DecodeError::new( @@ -211,16 +206,16 @@ pub fn scanstring<'a>( } #[inline] -fn decode_unicode(it: &mut I, pos: usize) -> Result +fn decode_unicode(it: &mut I, pos: usize) -> Result where - I: Iterator, + I: Iterator, { let err = || DecodeError::new("Invalid \\uXXXX escape", pos); let mut uni = 0; for x in (0..4).rev() { let (_, (_, c)) = it.next().ok_or_else(err)?; - let d = c.to_digit(16).ok_or_else(err)?; - uni += d * 16u32.pow(x); + let d = c.to_char().and_then(|c| c.to_digit(16)).ok_or_else(err)? as u16; + uni += d * 16u16.pow(x); } - Ok(uni) + Ok(uni.into()) } diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index e665d1e27a..01dd65f519 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -179,9 +179,12 @@ impl Constructor for PyComplex { "complex() can't take second arg if first is a string".to_owned(), )); } - let value = parse_str(s.as_str().trim()).ok_or_else(|| { - vm.new_value_error("complex() arg is a malformed string".to_owned()) - })?; + let value = s + .to_str() + .and_then(|s| parse_str(s.trim())) + .ok_or_else(|| { + vm.new_value_error("complex() arg is a malformed string".to_owned()) + })?; return Self::from(value) .into_ref_with_type(vm, cls) .map(Into::into); diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index b4601fbb92..48ccd2c437 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -161,7 +161,7 @@ impl Constructor for PyFloat { fn float_from_string(val: PyObjectRef, vm: &VirtualMachine) -> PyResult { let (bytearray, buffer, buffer_lock); let b = if let Some(s) = val.payload_if_subclass::(vm) { - s.as_str().trim().as_bytes() + s.as_wtf8().trim().as_bytes() } else if let Some(bytes) = val.payload_if_subclass::(vm) { bytes.as_bytes() } else if let Some(buf) = val.payload_if_subclass::(vm) { diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index aa9613e9d7..f457bf5ed8 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -847,7 +847,7 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult { - let s = string.as_str().trim(); + let s = string.as_wtf8().trim(); bytes_to_int(s.as_bytes(), base) } bytes @ PyBytes => { diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 8fe3904945..dfb9de9ba6 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -424,6 +424,23 @@ impl PyStr { self.data.as_str() } + pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { + self.to_str().ok_or_else(|| { + let start = self + .as_wtf8() + .code_points() + .position(|c| c.to_char().is_none()) + .unwrap(); + vm.new_unicode_encode_error_real( + identifier!(vm, utf_8).to_owned(), + vm.ctx.new_str(self.data.clone()), + start, + start + 1, + vm.ctx.new_str("surrogates not allowed"), + ) + }) + } + pub fn to_string_lossy(&self) -> Cow<'_, str> { self.to_str() .map(Cow::Borrowed) @@ -850,9 +867,9 @@ impl PyStr { /// If the string starts with the prefix string, return string[len(prefix):] /// Otherwise, return a copy of the original string. #[pymethod] - fn removeprefix(&self, pref: PyStrRef) -> String { - self.as_str() - .py_removeprefix(pref.as_str(), pref.byte_len(), |s, p| s.starts_with(p)) + fn removeprefix(&self, pref: PyStrRef) -> Wtf8Buf { + self.as_wtf8() + .py_removeprefix(pref.as_wtf8(), pref.byte_len(), |s, p| s.starts_with(p)) .to_owned() } @@ -861,9 +878,9 @@ impl PyStr { /// If the string ends with the suffix string, return string[:len(suffix)] /// Otherwise, return a copy of the original string. #[pymethod] - fn removesuffix(&self, suffix: PyStrRef) -> String { - self.as_str() - .py_removesuffix(suffix.as_str(), suffix.byte_len(), |s, p| s.ends_with(p)) + fn removesuffix(&self, suffix: PyStrRef) -> Wtf8Buf { + self.as_wtf8() + .py_removesuffix(suffix.as_wtf8(), suffix.byte_len(), |s, p| s.ends_with(p)) .to_owned() } @@ -1294,7 +1311,8 @@ impl PyStr { #[pymethod] fn isidentifier(&self) -> bool { - let mut chars = self.as_str().chars(); + let Some(s) = self.to_str() else { return false }; + let mut chars = s.chars(); let is_identifier_start = chars.next().is_some_and(|c| c == '_' || is_xid_start(c)); // a string is not an identifier if it has whitespace or starts with a number is_identifier_start && chars.all(is_xid_continue) diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 969d6db937..776c777cb3 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -884,7 +884,7 @@ impl Constructor for PyType { attributes .entry(identifier!(vm, __qualname__)) - .or_insert_with(|| vm.ctx.new_str(name.as_str()).into()); + .or_insert_with(|| name.clone().into()); if attributes.get(identifier!(vm, __eq__)).is_some() && attributes.get(identifier!(vm, __hash__)).is_none() diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 2b6720e843..dd039c2733 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -77,7 +77,7 @@ impl PyObject { )) }) } else if let Some(s) = self.payload::() { - try_convert(self, s.as_str().trim().as_bytes(), vm) + try_convert(self, s.as_wtf8().trim().as_bytes(), vm) } else if let Some(bytes) = self.payload::() { try_convert(self, bytes, vm) } else if let Some(bytearray) = self.payload::() { diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 4e69cf38a2..4cdcb68257 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -224,7 +224,7 @@ impl PyObject { dict: Option, vm: &VirtualMachine, ) -> PyResult> { - let name = name_str.as_str(); + let name = name_str.as_wtf8(); let obj_cls = self.class(); let cls_attr_name = vm.ctx.interned_str(name_str); let cls_attr = match cls_attr_name.and_then(|name| obj_cls.get_attr(name)) { diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 6ad2a74f4b..320d839682 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -26,7 +26,7 @@ mod _codecs { fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult { vm.state .codec_registry - .lookup(encoding.as_str(), vm) + .lookup(encoding.try_to_str(vm)?, vm) .map(|codec| codec.into_tuple().into()) } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 77d9231724..4cf3c058df 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -2245,7 +2245,7 @@ mod _io { let newline = args.newline.unwrap_or_default(); let (encoder, decoder) = - Self::find_coder(&buffer, encoding.as_str(), &errors, newline, vm)?; + Self::find_coder(&buffer, encoding.try_to_str(vm)?, &errors, newline, vm)?; *data = Some(TextIOData { buffer, @@ -2345,7 +2345,7 @@ mod _io { if let Some(encoding) = args.encoding { let (encoder, decoder) = Self::find_coder( &data.buffer, - encoding.as_str(), + encoding.try_to_str(vm)?, &data.errors, data.newline, vm, @@ -3468,9 +3468,9 @@ mod _io { // return the entire contents of the underlying #[pymethod] - fn getvalue(&self, vm: &VirtualMachine) -> PyResult { + fn getvalue(&self, vm: &VirtualMachine) -> PyResult { let bytes = self.buffer(vm)?.getvalue(); - String::from_utf8(bytes) + Wtf8Buf::from_bytes(bytes) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned())) } @@ -3491,10 +3491,10 @@ mod _io { // If k is undefined || k == -1, then we read all bytes until the end of the file. // This also increments the stream position by the value of k #[pymethod] - fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { + fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { let data = self.buffer(vm)?.read(size.to_usize()).unwrap_or_default(); - let value = String::from_utf8(data) + let value = Wtf8Buf::from_bytes(data) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned()))?; Ok(value) } @@ -3505,11 +3505,11 @@ mod _io { } #[pymethod] - fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { + fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { // TODO size should correspond to the number of characters, at the moments its the number of // bytes. let input = self.buffer(vm)?.readline(size.to_usize(), vm)?; - String::from_utf8(input) + Wtf8Buf::from_bytes(input) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned())) } diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index c464dc3abf..f98530e845 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -327,8 +327,12 @@ mod decl { * raises an error if unsupported format is supplied. * If error happens, we set result as input arg. */ - write!(&mut formatted_time, "{}", instant.format(format.as_str())) - .unwrap_or_else(|_| formatted_time = format.to_string()); + write!( + &mut formatted_time, + "{}", + instant.format(format.try_to_str(vm)?) + ) + .unwrap_or_else(|_| formatted_time = format.to_string()); Ok(vm.ctx.new_str(formatted_time).into()) } diff --git a/vm/src/utils.rs b/vm/src/utils.rs index e2bc993686..78edfb71cc 100644 --- a/vm/src/utils.rs +++ b/vm/src/utils.rs @@ -1,3 +1,5 @@ +use rustpython_common::wtf8::Wtf8; + use crate::{ PyObjectRef, PyResult, VirtualMachine, builtins::PyStr, @@ -18,9 +20,9 @@ impl ToPyObject for std::convert::Infallible { } } -pub trait ToCString: AsRef { +pub trait ToCString: AsRef { fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { - std::ffi::CString::new(self.as_ref()).map_err(|err| err.to_pyexception(vm)) + std::ffi::CString::new(self.as_ref().as_bytes()).map_err(|err| err.to_pyexception(vm)) } fn ensure_no_nul(&self, vm: &VirtualMachine) -> PyResult<()> { if self.as_ref().as_bytes().contains(&b'\0') { From dd467f6c73bc31e04c11f990b919014e298ee631 Mon Sep 17 00:00:00 2001 From: Noa Date: Thu, 27 Mar 2025 10:15:18 -0500 Subject: [PATCH 3/3] Update common/src/wtf8/mod.rs Co-authored-by: Jeong, YunWon <69878+youknowone@users.noreply.github.com> --- common/src/wtf8/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs index 57a1d6d7de..ff4dcf8900 100644 --- a/common/src/wtf8/mod.rs +++ b/common/src/wtf8/mod.rs @@ -763,7 +763,7 @@ impl Wtf8 { let mut rest = b; while let Err(e) = std::str::from_utf8(rest) { rest = &rest[e.valid_up_to()..]; - Self::decode_surrogate(rest)?; + let _ = Self::decode_surrogate(rest)?; rest = &rest[3..]; } Some(unsafe { Wtf8::from_bytes_unchecked(b) })