diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 4b1231f..5c09328 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -22,8 +22,8 @@ jobs: - name: Setup run: | - rustup toolchain install nightly-2024-05-18-x86_64-unknown-linux-gnu - rustup component add rust-src --toolchain nightly-2024-05-18-x86_64-unknown-linux-gnu + rustup toolchain install nightly-2025-04-15-x86_64-unknown-linux-gnu + rustup component add rust-src --toolchain nightly-2025-04-15-x86_64-unknown-linux-gnu rustup target add \ aarch64-linux-android \ armv7-linux-androideabi \ diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index aebd68e..4c6057e 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -14,8 +14,8 @@ jobs: - name: Setup run: | - rustup toolchain install nightly-2024-05-18-aarch64-apple-darwin - rustup component add rust-src --toolchain nightly-2024-05-18-aarch64-apple-darwin + rustup toolchain install nightly-2025-04-15-aarch64-apple-darwin + rustup component add rust-src --toolchain nightly-2025-04-15-aarch64-apple-darwin rustup target add \ x86_64-apple-darwin \ aarch64-apple-darwin \ diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index a91d5bf..0e8d7cb 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -15,7 +15,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binaries @@ -33,7 +33,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binaries diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 67d24fa..a59e0cd 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -15,7 +15,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary @@ -33,7 +33,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 36028f3..9c9564b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -52,8 +52,8 @@ jobs: - name: Setup run: | - rustup toolchain install nightly-2024-05-18-x86_64-unknown-linux-gnu - rustup component add rust-src --toolchain nightly-2024-05-18-x86_64-unknown-linux-gnu + rustup toolchain install nightly-2025-04-15-x86_64-unknown-linux-gnu + rustup component add rust-src --toolchain nightly-2025-04-15-x86_64-unknown-linux-gnu rustup target add \ aarch64-linux-android \ armv7-linux-androideabi \ @@ -84,8 +84,8 @@ jobs: - name: Setup run: | - rustup toolchain install nightly-2024-05-18-aarch64-apple-darwin - rustup component add rust-src --toolchain nightly-2024-05-18-aarch64-apple-darwin + rustup toolchain install nightly-2025-04-15-aarch64-apple-darwin + rustup component add rust-src --toolchain nightly-2025-04-15-aarch64-apple-darwin rustup target add \ x86_64-apple-darwin \ aarch64-apple-darwin \ @@ -153,7 +153,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build 
binaries @@ -178,7 +178,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binaries @@ -203,7 +203,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary @@ -228,7 +228,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary @@ -253,7 +253,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary @@ -278,7 +278,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Setup emsdk diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 20ec38d..3d771bd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build diff --git a/.github/workflows/wasm.yml b/.github/workflows/wasm.yml index eab0977..a37d3ab 100644 --- a/.github/workflows/wasm.yml +++ b/.github/workflows/wasm.yml @@ -8,20 +8,20 @@ jobs: if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Setup emsdk uses: mymindstorm/setup-emsdk@v14 with: - version: 3.1.68 + version: 4.0.7 - name: Build WASM run: ./tool/build_wasm.sh diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index fa13aab..5ac33a3 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -15,7 +15,7 @@ jobs: - name: Install Rust Nightly uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly-2024-05-18 + toolchain: nightly-2025-04-15 components: rust-src - name: Build binary diff --git a/Cargo.lock b/Cargo.lock index 0f3c7d5..fdacb13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,7 +34,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 1.1.0", "shlex", "syn 2.0.100", "which", @@ -113,6 +113,22 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-lite" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "getrandom" version = "0.2.10" @@ -235,14 +251,22 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "powersync_core" version = "0.3.14" dependencies = [ "bytes", "const_format", + "futures-lite", "num-derive 0.3.3", "num-traits", + "rustc-hash 2.1.1", "serde", "serde_json", "sqlite_nostd", @@ -330,6 +354,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "ryu" version = "1.0.15" diff --git a/Cargo.toml b/Cargo.toml index 3506783..eb7eb75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,6 @@ default-members = ["crates/shell", "crates/sqlite"] [profile.dev] panic = "abort" -strip = true [profile.release] panic = "abort" diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 9aef67e..3839d97 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -18,9 +18,11 @@ bytes = { version = "1.4", default-features = false } num-traits = { version = "0.2.15", default-features = false } num-derive = "0.3" serde_json = { version = "1.0", default-features = false, features = ["alloc"] } -serde = { version = "1.0", default-features = false, features = ["alloc", "derive"] } -streaming-iterator = { version = "0.1.9", default-features = false, features = ["alloc"] } +serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } const_format = "0.2.34" +futures-lite = { version = "2.6.0", default-features = false, features = ["alloc"] } +rustc-hash = { version = "2.1", default-features = false } +streaming-iterator = { version = "0.1.9", default-features = false, features = ["alloc"] } [dependencies.uuid] version = "1.4.1" diff --git a/crates/core/src/bson/de.rs b/crates/core/src/bson/de.rs new file mode 100644 index 0000000..d6d3e01 --- /dev/null +++ b/crates/core/src/bson/de.rs @@ -0,0 +1,322 @@ +use core::assert_matches::debug_assert_matches; + +use serde::{ + de::{ + self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, + Visitor, + }, + forward_to_deserialize_any, +}; + +use super::{ + error::ErrorKind, + parser::{ElementType, Parser}, + BsonError, +}; + +pub struct Deserializer<'de> { + parser: Parser<'de>, + position: DeserializerPosition, +} + +#[derive(Clone, Debug)] +enum DeserializerPosition { + /// The deserializer is outside of the initial document header. + OutsideOfDocument, + /// The deserializer expects the beginning of a key-value pair, or the end of the current + /// document. + BeforeTypeOrAtEndOfDocument, + /// The deserializer has read past the type of a key-value pair, but did not scan the name yet. + BeforeName { pending_type: ElementType }, + /// Read type and name of a key-value pair, position is before the value now. + BeforeValue { pending_type: ElementType }, +} + +impl<'de> Deserializer<'de> { + /// When used as a name hint to [de::Deserialize.deserialize_enum], the BSON deserializer will + /// report documents a byte array view instead of parsing them. 
+ /// + /// This is used as an internal optimization when we want to keep a reference to a BSON sub- + /// document without actually inspecting the structure of that document. + pub const SPECIAL_CASE_EMBEDDED_DOCUMENT: &'static str = "\0SpecialCaseEmbedDoc"; + + fn outside_of_document(parser: Parser<'de>) -> Self { + Self { + parser, + position: DeserializerPosition::OutsideOfDocument, + } + } + + pub fn from_bytes(bytes: &'de [u8]) -> Self { + let parser = Parser::new(bytes); + Self::outside_of_document(parser) + } + + fn prepare_to_read(&mut self, allow_key: bool) -> Result, BsonError> { + match self.position.clone() { + DeserializerPosition::OutsideOfDocument => { + // The next value we're reading is a document + self.position = DeserializerPosition::BeforeValue { + pending_type: ElementType::Document, + }; + Ok(KeyOrValue::PendingValue(ElementType::Document)) + } + DeserializerPosition::BeforeValue { pending_type } => { + Ok(KeyOrValue::PendingValue(pending_type)) + } + DeserializerPosition::BeforeTypeOrAtEndOfDocument { .. } => { + Err(self.parser.error(ErrorKind::InvalidStateExpectedType)) + } + DeserializerPosition::BeforeName { pending_type } => { + if !allow_key { + return Err(self.parser.error(ErrorKind::InvalidStateExpectedName)); + } + + self.position = DeserializerPosition::BeforeValue { + pending_type: pending_type, + }; + Ok(KeyOrValue::Key(self.parser.read_cstr()?)) + } + } + } + + fn prepare_to_read_value(&mut self) -> Result { + let result = self.prepare_to_read(false)?; + match result { + KeyOrValue::Key(_) => unreachable!(), + KeyOrValue::PendingValue(element_type) => Ok(element_type), + } + } + + fn object_reader(&mut self) -> Result, BsonError> { + let parser = self.parser.document_scope()?; + let deserializer = Deserializer { + parser, + position: DeserializerPosition::BeforeTypeOrAtEndOfDocument, + }; + Ok(deserializer) + } + + fn advance_to_next_name(&mut self) -> Result, BsonError> { + if self.parser.end_document()? { + return Ok(None); + } + + self.position = DeserializerPosition::BeforeName { + pending_type: self.parser.read_element_type()?, + }; + Ok(Some(())) + } +} + +impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = BsonError; + + fn is_human_readable(&self) -> bool { + false + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let element_type = match self.prepare_to_read(true)? { + KeyOrValue::Key(name) => return visitor.visit_borrowed_str(name), + KeyOrValue::PendingValue(element_type) => element_type, + }; + + match element_type { + ElementType::Double => visitor.visit_f64(self.parser.read_double()?), + ElementType::String => visitor.visit_borrowed_str(self.parser.read_string()?), + ElementType::Document => { + let mut object = self.object_reader()?; + visitor.visit_map(&mut object) + } + ElementType::Array => { + let mut object = self.object_reader()?; + visitor.visit_seq(&mut object) + } + ElementType::Binary => { + let (_, bytes) = self.parser.read_binary()?; + visitor.visit_borrowed_bytes(bytes) + } + ElementType::ObjectId => visitor.visit_borrowed_bytes(self.parser.read_object_id()?), + ElementType::Boolean => visitor.visit_bool(self.parser.read_bool()?), + ElementType::DatetimeUtc | ElementType::Timestamp => { + visitor.visit_u64(self.parser.read_uint64()?) 
+ } + ElementType::Null | ElementType::Undefined => visitor.visit_unit(), + ElementType::Int32 => visitor.visit_i32(self.parser.read_int32()?), + ElementType::Int64 => visitor.visit_i64(self.parser.read_int64()?), + } + } + + fn deserialize_enum( + self, + name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + let kind = self.prepare_to_read_value()?; + + // With this special name, the visitor indicates that it doesn't actually want to read an + // enum, it wants to read values regularly. Except that a document appearing at this + // position should not be parsed, it should be forwarded as an embedded byte array. + if name == Deserializer::SPECIAL_CASE_EMBEDDED_DOCUMENT { + return if matches!(kind, ElementType::Document) { + let object = self.parser.skip_document()?; + visitor.visit_borrowed_bytes(object) + } else { + self.deserialize_any(visitor) + }; + } + + match kind { + ElementType::String => { + visitor.visit_enum(self.parser.read_string()?.into_deserializer()) + } + ElementType::Document => { + let mut object = self.object_reader()?; + visitor.visit_enum(&mut object) + } + _ => Err(self.parser.error(ErrorKind::ExpectedEnum { actual: kind })), + } + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let kind = self.prepare_to_read_value()?; + match kind { + ElementType::Null => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.prepare_to_read_value()?; + visitor.visit_newtype_struct(self) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf unit unit_struct seq tuple + tuple_struct map struct ignored_any identifier + } +} + +impl<'de> MapAccess<'de> for Deserializer<'de> { + type Error = BsonError; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + if let None = self.advance_to_next_name()? { + return Ok(None); + } + Ok(Some(seed.deserialize(self)?)) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(self) + } +} + +impl<'de> SeqAccess<'de> for Deserializer<'de> { + type Error = BsonError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + // Array elements are encoded as an object like `{"0": value, "1": another}` + if let None = self.advance_to_next_name()? { + return Ok(None); + } + + // Skip name + debug_assert_matches!(self.position, DeserializerPosition::BeforeName { .. }); + self.prepare_to_read(true)?; + + // And deserialize value! + Ok(Some(seed.deserialize(self)?)) + } +} + +impl<'a, 'de> EnumAccess<'de> for &'a mut Deserializer<'de> { + type Error = BsonError; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + if let None = self.advance_to_next_name()? { + return Err(self + .parser + .error(ErrorKind::UnexpectedEndOfDocumentForEnumVariant)); + } + + let value = seed.deserialize(&mut *self)?; + Ok((value, self)) + } +} + +impl<'a, 'de> VariantAccess<'de> for &'a mut Deserializer<'de> { + type Error = BsonError; + + fn unit_variant(self) -> Result<(), Self::Error> { + // Unit variants are encoded as simple string values, which are handled directly in + // Deserializer::deserialize_enum. 
+ Err(self.parser.error(ErrorKind::ExpectedString)) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + // Newtype variants are represented as `{ NAME: VALUE }`, so we just have to deserialize the + // value here. + seed.deserialize(self) + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + // Tuple variants are represented as `{ NAME: VALUES[] }`, so we deserialize the array here. + de::Deserializer::deserialize_seq(self, visitor) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + // Struct variants are represented as `{ NAME: { ... } }`, so we deserialize the struct. + de::Deserializer::deserialize_map(self, visitor) + } +} + +enum KeyOrValue<'de> { + Key(&'de str), + PendingValue(ElementType), +} diff --git a/crates/core/src/bson/error.rs b/crates/core/src/bson/error.rs new file mode 100644 index 0000000..dafeff3 --- /dev/null +++ b/crates/core/src/bson/error.rs @@ -0,0 +1,98 @@ +use core::{fmt::Display, str::Utf8Error}; + +use alloc::{ + boxed::Box, + string::{String, ToString}, +}; +use serde::de::{self, StdError}; + +use super::parser::ElementType; + +#[derive(Debug)] +pub struct BsonError { + /// Using a [Box] here keeps the size of this type as small, which makes results of this error + /// type smaller (at the cost of making errors more expensive to report, but that's fine because + /// we expect them to be rare). + err: Box, +} + +#[derive(Debug)] +struct BsonErrorImpl { + offset: Option, + kind: ErrorKind, +} + +#[derive(Debug)] +pub enum ErrorKind { + Custom(String), + UnknownElementType(i8), + UnterminatedCString, + InvalidCString(Utf8Error), + UnexpectedEoF, + InvalidEndOfDocument, + InvalidSize, + InvalidStateExpectedType, + InvalidStateExpectedName, + InvalidStateExpectedValue, + ExpectedEnum { actual: ElementType }, + ExpectedString, + UnexpectedEndOfDocumentForEnumVariant, +} + +impl BsonError { + pub fn new(offset: Option, kind: ErrorKind) -> Self { + Self { + err: Box::new(BsonErrorImpl { offset, kind }), + } + } +} + +impl Display for BsonError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.err.fmt(f) + } +} + +impl Display for BsonErrorImpl { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + if let Some(offset) = self.offset { + write!(f, "bson error, at {offset}: {}", self.kind) + } else { + write!(f, "bson error at unknown offset: {}", self.kind) + } + } +} + +impl Display for ErrorKind { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ErrorKind::Custom(msg) => write!(f, "custom {msg}"), + ErrorKind::UnknownElementType(code) => write!(f, "unknown element code: {code}"), + ErrorKind::UnterminatedCString => write!(f, "unterminated cstring"), + ErrorKind::InvalidCString(e) => write!(f, "cstring with non-utf8 content: {e}"), + ErrorKind::UnexpectedEoF => write!(f, "unexpected end of file"), + ErrorKind::InvalidEndOfDocument => write!(f, "unexpected end of document"), + ErrorKind::InvalidSize => write!(f, "invalid document size"), + ErrorKind::InvalidStateExpectedType => write!(f, "internal state error, expected type"), + ErrorKind::InvalidStateExpectedName => write!(f, "internal state error, expected name"), + ErrorKind::InvalidStateExpectedValue => { + write!(f, "internal state error, expected value") + } + ErrorKind::ExpectedEnum { actual } => write!(f, "expected enum, got {}", *actual as u8), + 
ErrorKind::ExpectedString => write!(f, "expected a string value"), + ErrorKind::UnexpectedEndOfDocumentForEnumVariant => { + write!(f, "unexpected end of document for enum variant") + } + } + } +} + +impl de::Error for BsonError { + fn custom(msg: T) -> Self + where + T: Display, + { + BsonError::new(None, ErrorKind::Custom(msg.to_string())) + } +} +impl StdError for BsonError {} diff --git a/crates/core/src/bson/mod.rs b/crates/core/src/bson/mod.rs new file mode 100644 index 0000000..7d27620 --- /dev/null +++ b/crates/core/src/bson/mod.rs @@ -0,0 +1,371 @@ +pub use de::Deserializer; +pub use error::BsonError; +use serde::Deserialize; + +mod de; +mod error; +mod parser; + +/// Deserializes BSON [bytes] into a structure [T]. +pub fn from_bytes<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result { + let mut deserializer = Deserializer::from_bytes(bytes); + + T::deserialize(&mut deserializer) +} + +#[cfg(test)] +mod test { + use alloc::{vec, vec::Vec}; + use core::assert_matches::assert_matches; + + use crate::sync::line::{SyncLine, TokenExpiresIn}; + + use super::*; + + #[test] + fn test_hello_world() { + // {"hello": "world"} + let bson = b"\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"; + + #[derive(Deserialize)] + struct Expected<'a> { + hello: &'a str, + } + + let expected: Expected = from_bytes(bson.as_slice()).expect("should deserialize"); + assert_eq!(expected.hello, "world"); + } + + #[test] + fn test_checkpoint_line() { + let bson = b"\x85\x00\x00\x00\x03checkpoint\x00t\x00\x00\x00\x02last_op_id\x00\x02\x00\x00\x001\x00\x0awrite_checkpoint\x00\x04buckets\x00B\x00\x00\x00\x030\x00:\x00\x00\x00\x02bucket\x00\x02\x00\x00\x00a\x00\x10checksum\x00\x00\x00\x00\x00\x10priority\x00\x03\x00\x00\x00\x10count\x00\x01\x00\x00\x00\x00\x00\x00\x00"; + + let expected: SyncLine = from_bytes(bson.as_slice()).expect("should deserialize"); + let SyncLine::Checkpoint(checkpoint) = expected else { + panic!("Expected to deserialize as checkpoint line") + }; + + assert_eq!(checkpoint.buckets.len(), 1); + } + + #[test] + fn test_newtype_tuple() { + let bson = b"\x1b\x00\x00\x00\x10token_expires_in\x00<\x00\x00\x00\x00"; + + let expected: SyncLine = from_bytes(bson.as_slice()).expect("should deserialize"); + assert_matches!(expected, SyncLine::KeepAlive(TokenExpiresIn(60))); + } + + #[test] + fn test_int64_positive_max() { + // {"value": 9223372036854775807} (i64::MAX) + let bson = b"\x14\x00\x00\x00\x12value\x00\xff\xff\xff\xff\xff\xff\xff\x7f\x00"; + + #[derive(Deserialize)] + struct TestDoc { + value: i64, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.value, i64::MAX); + } + + #[test] + fn test_int64_negative_max() { + // {"value": -9223372036854775808} (i64::MIN) + let bson = b"\x14\x00\x00\x00\x12value\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00"; + + #[derive(Deserialize)] + struct TestDoc { + value: i64, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.value, i64::MIN); + } + + #[test] + fn test_int64_negative_one() { + // {"value": -1} + let bson = b"\x14\x00\x00\x00\x12value\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00"; + + #[derive(Deserialize)] + struct TestDoc { + value: i64, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.value, -1); + } + + #[test] + fn test_int64_negative_small() { + // {"value": -42} + let bson = b"\x14\x00\x00\x00\x12value\x00\xd6\xff\xff\xff\xff\xff\xff\xff\x00"; + + #[derive(Deserialize)] + struct TestDoc { + value: i64, + } + + let doc: 
TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.value, -42); + } + + #[test] + fn test_int32_negative_values() { + // {"small": -1, "large": -2147483648} (i32::MIN) + let bson = + b"\x1b\x00\x00\x00\x10small\x00\xff\xff\xff\xff\x10large\x00\x00\x00\x00\x80\x00"; + + #[derive(Deserialize)] + struct TestDoc { + small: i32, + large: i32, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.small, -1); + assert_eq!(doc.large, i32::MIN); + } + + #[test] + fn test_double_negative_values() { + // {"neg": -3.14159} + let bson = b"\x12\x00\x00\x00\x01neg\x00\x6e\x86\x1b\xf0\xf9\x21\x09\xc0\x00"; + + #[derive(Deserialize)] + struct TestDoc { + neg: f64, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert!((doc.neg - (-3.14159)).abs() < 0.00001); + } + + #[test] + fn test_double_special_values() { + // Test infinity, negative infinity, and NaN representations + // {"inf": Infinity, "ninf": -Infinity, "nan": NaN} + let bson = b"\x2d\x00\x00\x00\x01\x69\x6e\x66\x00\x00\x00\x00\x00\x00\x00\xf0\x7f\x01\x6e\x69\x6e\x66\x00\x00\x00\x00\x00\x00\x00\xf0\xff\x01\x6e\x61\x6e\x00\x00\x00\x00\x00\x00\x00\xf8\x7f\x00"; + + #[derive(Deserialize)] + struct TestDoc { + inf: f64, + ninf: f64, + nan: f64, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.inf, f64::INFINITY); + assert_eq!(doc.ninf, f64::NEG_INFINITY); + assert!(doc.nan.is_nan()); + } + + #[test] + fn test_empty_string() { + // {"empty": ""} + let bson = b"\x11\x00\x00\x00\x02empty\x00\x01\x00\x00\x00\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc<'a> { + empty: &'a str, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.empty, ""); + } + + #[test] + fn test_unicode_string() { + // {"unicode": "🦀💖"} + let bson = b"\x1b\x00\x00\x00\x02unicode\x00\x09\x00\x00\x00\xf0\x9f\xa6\x80\xf0\x9f\x92\x96\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc<'a> { + unicode: &'a str, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.unicode, "🦀💖"); + } + + #[test] + fn test_boolean_values() { + // {"true_val": true, "false_val": false} + let bson = b"\x1c\x00\x00\x00\x08true_val\x00\x01\x08false_val\x00\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc { + true_val: bool, + false_val: bool, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.true_val, true); + assert_eq!(doc.false_val, false); + } + + #[test] + fn test_null_value() { + // {"null_val": null} + let bson = b"\x0f\x00\x00\x00\x0anull_val\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc { + null_val: Option, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.null_val, None); + } + + #[test] + fn test_empty_document() { + // {} + let bson = b"\x05\x00\x00\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc {} + + let _doc: TestDoc = from_bytes(bson).expect("should deserialize"); + } + + #[test] + fn test_nested_document() { + // {"nested": {"inner": 42}} + let bson = + b"\x1d\x00\x00\x00\x03nested\x00\x10\x00\x00\x00\x10inner\x00*\x00\x00\x00\x00\x00"; + + #[derive(Deserialize)] + struct Inner { + inner: i32, + } + + #[derive(Deserialize)] + struct TestDoc { + nested: Inner, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.nested.inner, 42); + } + + #[test] + fn test_array_with_integers() { + // {"array": [1, 2]} - simplified array test + // Array format: 
{"0": 1, "1": 2} + let bson = b"\x1f\x00\x00\x00\x04array\x00\x13\x00\x00\x00\x100\x00\x01\x00\x00\x00\x101\x00\x02\x00\x00\x00\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc { + array: Vec, + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.array, vec![1, 2]); + } + + #[test] + fn test_binary_data() { + // {"binary": } + let bson = b"\x16\x00\x00\x00\x05binary\x00\x04\x00\x00\x00\x00\x01\x02\x03\x04\x00"; + + #[derive(Deserialize)] + struct TestDoc<'a> { + binary: &'a [u8], + } + + let doc: TestDoc = from_bytes(bson).expect("should deserialize"); + assert_eq!(doc.binary, &[1, 2, 3, 4]); + } + + // Error case tests + + #[test] + fn test_invalid_element_type() { + // Document with invalid element type (99) + let bson = b"\x10\x00\x00\x00\x63test\x00\x01\x00\x00\x00\x00"; + + #[derive(Deserialize)] + #[allow(dead_code)] + struct TestDoc { + test: i32, + } + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } + + #[test] + fn test_truncated_document() { + // Document claims to be longer than actual data + let bson = b"\xff\x00\x00\x00\x10test\x00"; + + #[derive(Deserialize)] + #[allow(dead_code)] + struct TestDoc { + test: i32, + } + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_string_length() { + // String with invalid length + let bson = b"\x15\x00\x00\x00\x02test\x00\xff\xff\xff\xff\x00"; + + #[derive(Deserialize)] + #[allow(dead_code)] + struct TestDoc<'a> { + test: &'a str, + } + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } + + #[test] + fn test_unterminated_cstring() { + // Document with field name that doesn't have null terminator + let bson = b"\x10\x00\x00\x00\x10test\x01\x00\x00\x00\x00\x00"; + + #[derive(Deserialize)] + #[allow(dead_code)] + struct TestDoc { + test: i32, + } + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } + + #[test] + fn test_document_without_terminator() { + // Document missing the final null byte + let bson = b"\x0d\x00\x00\x00\x10test\x00*\x00\x00\x00"; + + #[derive(Deserialize)] + #[allow(dead_code)] + struct TestDoc { + test: i32, + } + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_document_size() { + // Document with size less than minimum (5 bytes) + let bson = b"\x04\x00\x00\x00\x00"; + + #[derive(Deserialize)] + struct TestDoc {} + + let result: Result = from_bytes(bson); + assert!(result.is_err()); + } +} diff --git a/crates/core/src/bson/parser.rs b/crates/core/src/bson/parser.rs new file mode 100644 index 0000000..ddd31b1 --- /dev/null +++ b/crates/core/src/bson/parser.rs @@ -0,0 +1,523 @@ +use core::ffi::CStr; + +use super::{error::ErrorKind, BsonError}; +use num_traits::{FromBytes, Num}; + +pub struct Parser<'de> { + offset: usize, + remaining_input: &'de [u8], +} + +impl<'de> Parser<'de> { + pub fn new(source: &'de [u8]) -> Self { + Self { + offset: 0, + remaining_input: source, + } + } + + #[cold] + pub fn error(&self, kind: ErrorKind) -> BsonError { + BsonError::new(Some(self.offset), kind) + } + + /// Advances the position of the parser, panicking on bound errors. + fn advance(&mut self, by: usize) { + self.offset = self.offset.strict_add(by); + self.remaining_input = &self.remaining_input[by..]; + } + + /// Reads a sized buffer from the parser and advances the input accordingly. + /// + /// This returns an error if not enough bytes are left in the input. 
+ fn advance_checked(&mut self, size: usize) -> Result<&'de [u8], BsonError> { + let (taken, rest) = self + .remaining_input + .split_at_checked(size) + .ok_or_else(|| self.error(ErrorKind::UnexpectedEoF))?; + + self.offset += size; + self.remaining_input = rest; + Ok(taken) + } + + fn advance_byte(&mut self) -> Result { + let value = *self + .remaining_input + .split_off_first() + .ok_or_else(|| self.error(ErrorKind::UnexpectedEoF))?; + + Ok(value) + } + + fn advance_bytes(&mut self) -> Result<&'de [u8; N], BsonError> { + let bytes = self.advance_checked(N)?; + Ok(bytes.try_into().expect("should have correct length")) + } + + pub fn read_cstr(&mut self) -> Result<&'de str, BsonError> { + let raw = CStr::from_bytes_until_nul(self.remaining_input) + .map_err(|_| self.error(ErrorKind::UnterminatedCString))?; + let str = raw + .to_str() + .map_err(|e| self.error(ErrorKind::InvalidCString(e)))?; + + self.advance(str.len() + 1); + Ok(str) + } + + fn read_number>( + &mut self, + ) -> Result { + let bytes = self.advance_bytes::()?; + Ok(T::from_le_bytes(&bytes)) + } + + pub fn read_int32(&mut self) -> Result { + self.read_number() + } + + fn read_length(&mut self) -> Result { + let raw = self.read_int32()?; + u32::try_from(raw) + .and_then(usize::try_from) + .map_err(|_| self.error(ErrorKind::InvalidSize)) + } + + pub fn read_int64(&mut self) -> Result { + self.read_number() + } + + pub fn read_uint64(&mut self) -> Result { + self.read_number() + } + + pub fn read_double(&mut self) -> Result { + self.read_number() + } + + pub fn read_bool(&mut self) -> Result { + let byte = self.advance_byte()?; + Ok(byte != 0) + } + + pub fn read_object_id(&mut self) -> Result<&'de [u8], BsonError> { + self.advance_checked(12) + } + + /// Reads a BSON string, `string ::= int32 (byte*) unsigned_byte(0)` + pub fn read_string(&mut self) -> Result<&'de str, BsonError> { + let length_including_null = self.read_length()?; + let bytes = self.advance_checked(length_including_null)?; + + str::from_utf8(&bytes[..length_including_null - 1]) + .map_err(|e| self.error(ErrorKind::InvalidCString(e))) + } + + pub fn read_binary(&mut self) -> Result<(BinarySubtype, &'de [u8]), BsonError> { + let length = self.read_length()?; + let subtype = self.advance_byte()?; + let binary = self.advance_checked(length)?; + + Ok((BinarySubtype(subtype), binary)) + } + + pub fn read_element_type(&mut self) -> Result { + let raw_type = self.advance_byte()? as i8; + Ok(match raw_type { + 1 => ElementType::Double, + 2 => ElementType::String, + 3 => ElementType::Document, + 4 => ElementType::Array, + 5 => ElementType::Binary, + 6 => ElementType::Undefined, + 7 => ElementType::ObjectId, + 8 => ElementType::Boolean, + 9 => ElementType::DatetimeUtc, + 10 => ElementType::Null, + 16 => ElementType::Int32, + 17 => ElementType::Timestamp, + 18 => ElementType::Int64, + _ => return Err(self.error(ErrorKind::UnknownElementType(raw_type))), + }) + } + + fn subreader(&mut self, len: usize) -> Result, BsonError> { + let current_offset = self.offset; + let for_sub_reader = self.advance_checked(len)?; + Ok(Parser { + offset: current_offset, + remaining_input: for_sub_reader, + }) + } + + /// Reads a document header and skips over the contents of the document. + /// + /// Returns a new [Parser] that can only read contents of the document. 
+ pub fn document_scope(&mut self) -> Result, BsonError> { + let total_size = self.read_length()?; + if total_size < 5 { + return Err(self.error(ErrorKind::InvalidSize))?; + } + + self.subreader(total_size - 4) + } + + /// Skips over a document at the current offset, returning the bytes making up the document. + pub fn skip_document(&mut self) -> Result<&'de [u8], BsonError> { + let Some(peek_size) = self.remaining_input.get(0..4) else { + return Err(self.error(ErrorKind::UnexpectedEoF)); + }; + + let parsed_size = u32::try_from(i32::from_le_bytes( + peek_size.try_into().expect("should have correct length"), + )) + .and_then(usize::try_from) + .map_err(|_| self.error(ErrorKind::InvalidSize))?; + + if parsed_size < 5 || parsed_size >= self.remaining_input.len() { + return Err(self.error(ErrorKind::InvalidSize))?; + } + + Ok(self.subreader(parsed_size)?.remaining()) + } + + /// If only a single byte is left in the current scope, validate that it is a zero byte. + /// + /// Otherwise returns false as we haven't reached the end of a document. + pub fn end_document(&mut self) -> Result { + Ok(if self.remaining_input.len() == 1 { + let trailing_zero = self.advance_byte()?; + if trailing_zero != 0 { + return Err(self.error(ErrorKind::InvalidEndOfDocument)); + } + + true + } else { + false + }) + } + + pub fn remaining(&self) -> &'de [u8] { + self.remaining_input + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_read_int64_negative_values() { + let neg_one_bytes = (-1i64).to_le_bytes(); + let mut parser = Parser::new(&neg_one_bytes); + assert_eq!(parser.read_int64().unwrap(), -1); + + let min_bytes = (i64::MIN).to_le_bytes(); + let mut parser = Parser::new(&min_bytes); + assert_eq!(parser.read_int64().unwrap(), i64::MIN); + + let neg_42_bytes = (-42i64).to_le_bytes(); + let mut parser = Parser::new(&neg_42_bytes); + assert_eq!(parser.read_int64().unwrap(), -42); + } + + #[test] + fn test_read_int32_negative_values() { + let neg_one_bytes = (-1i32).to_le_bytes(); + let mut parser = Parser::new(&neg_one_bytes); + assert_eq!(parser.read_int32().unwrap(), -1); + + let min_bytes = (i32::MIN).to_le_bytes(); + let mut parser = Parser::new(&min_bytes); + assert_eq!(parser.read_int32().unwrap(), i32::MIN); + + let neg_42_bytes = (-42i32).to_le_bytes(); + let mut parser = Parser::new(&neg_42_bytes); + assert_eq!(parser.read_int32().unwrap(), -42); + } + + #[test] + fn test_read_double_negative_and_special() { + let neg_pi_bytes = (-3.14159f64).to_le_bytes(); + let mut parser = Parser::new(&neg_pi_bytes); + let val = parser.read_double().unwrap(); + assert!((val - (-3.14159)).abs() < 0.00001); + + let neg_inf_bytes = f64::NEG_INFINITY.to_le_bytes(); + let mut parser = Parser::new(&neg_inf_bytes); + assert_eq!(parser.read_double().unwrap(), f64::NEG_INFINITY); + + let nan_bytes = f64::NAN.to_le_bytes(); + let mut parser = Parser::new(&nan_bytes); + assert!(parser.read_double().unwrap().is_nan()); + } + + #[test] + fn test_read_bool_edge_cases() { + let mut parser = Parser::new(&[0x00]); + assert_eq!(parser.read_bool().unwrap(), false); + + let mut parser = Parser::new(&[0x01]); + assert_eq!(parser.read_bool().unwrap(), true); + + let mut parser = Parser::new(&[0xFF]); + assert_eq!(parser.read_bool().unwrap(), true); + + let mut parser = Parser::new(&[0x7F]); + assert_eq!(parser.read_bool().unwrap(), true); + } + + #[test] + fn test_read_string_empty() { + // Empty string: length=1, content=null terminator + let data = &[0x01, 0x00, 0x00, 0x00, 0x00]; + let mut parser = 
Parser::new(data); + assert_eq!(parser.read_string().unwrap(), ""); + } + + #[test] + fn test_read_string_unicode() { + // String "🦀" (4 UTF-8 bytes + null terminator) + let data = &[0x05, 0x00, 0x00, 0x00, 0xf0, 0x9f, 0xa6, 0x80, 0x00]; + let mut parser = Parser::new(data); + assert_eq!(parser.read_string().unwrap(), "🦀"); + } + + #[test] + fn test_read_cstr_empty() { + let data = &[0x00]; + let mut parser = Parser::new(data); + assert_eq!(parser.read_cstr().unwrap(), ""); + } + + #[test] + fn test_read_cstr_unicode() { + let data = &[0xf0, 0x9f, 0xa6, 0x80, 0x00]; // "🦀\0" + let mut parser = Parser::new(data); + assert_eq!(parser.read_cstr().unwrap(), "🦀"); + } + + #[test] + fn test_element_type_all_valid() { + let valid_types = [ + (1, ElementType::Double), + (2, ElementType::String), + (3, ElementType::Document), + (4, ElementType::Array), + (5, ElementType::Binary), + (6, ElementType::Undefined), + (7, ElementType::ObjectId), + (8, ElementType::Boolean), + (9, ElementType::DatetimeUtc), + (10, ElementType::Null), + (16, ElementType::Int32), + (17, ElementType::Timestamp), + (18, ElementType::Int64), + ]; + + for (byte, expected) in valid_types { + let data = [byte]; + let mut parser = Parser::new(&data); + let result = parser.read_element_type().unwrap(); + assert_eq!(result as u8, expected as u8); + } + } + + #[test] + fn test_element_type_invalid() { + let invalid_types = [0, 11, 12, 13, 14, 15, 19, 20, 99, 255]; + + for invalid_type in invalid_types { + let data = [invalid_type]; + let mut parser = Parser::new(&data); + let result = parser.read_element_type(); + assert!(result.is_err()); + } + } + + #[test] + fn test_document_scope_minimum_size() { + // Minimum valid document: 5 bytes total + let data = &[0x05, 0x00, 0x00, 0x00, 0x00]; + let mut parser = Parser::new(data); + let sub_parser = parser.document_scope().unwrap(); + assert_eq!(sub_parser.remaining().len(), 1); // Just the terminator + } + + #[test] + fn test_document_scope_invalid_size() { + // Document claiming size < 5 + let data = &[0x04, 0x00, 0x00, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.document_scope().is_err()); + } + + #[test] + fn test_binary_data_empty() { + // Binary with length 0, subtype 0 + let data = &[0x00, 0x00, 0x00, 0x00, 0x00]; + let mut parser = Parser::new(data); + let (subtype, binary) = parser.read_binary().unwrap(); + assert_eq!(subtype.0, 0); + assert_eq!(binary.len(), 0); + } + + #[test] + fn test_binary_data_with_content() { + // Binary with length 3, subtype 5, content [1,2,3] + let data = &[0x03, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02, 0x03]; + let mut parser = Parser::new(data); + let (subtype, binary) = parser.read_binary().unwrap(); + assert_eq!(subtype.0, 5); + assert_eq!(binary, &[1, 2, 3]); + } + + #[test] + fn test_object_id_exact_size() { + let data = &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c]; + let mut parser = Parser::new(data); + let oid = parser.read_object_id().unwrap(); + assert_eq!(oid, data); + } + + #[test] + fn test_advance_checked_boundary() { + let data = &[0x01, 0x02, 0x03]; + let mut parser = Parser::new(data); + + // Should succeed + assert!(parser.advance_checked(3).is_ok()); + assert_eq!(parser.remaining().len(), 0); + + // Should fail - no more data + assert!(parser.advance_checked(1).is_err()); + } + + #[test] + fn test_end_document_valid() { + let data = &[0x00]; + let mut parser = Parser::new(data); + assert_eq!(parser.end_document().unwrap(), true); + assert_eq!(parser.remaining().len(), 0); + } + + #[test] + fn 
test_end_document_invalid_terminator() { + let data = &[0x01]; + let mut parser = Parser::new(data); + assert!(parser.end_document().is_err()); + } + + #[test] + fn test_end_document_not_at_end() { + let data = &[0x01, 0x02, 0x03]; + let mut parser = Parser::new(data); + assert_eq!(parser.end_document().unwrap(), false); + } + + // Error boundary tests + + #[test] + fn test_unexpected_eof_int32() { + let data = &[0x01, 0x02]; // Only 2 bytes, need 4 + let mut parser = Parser::new(data); + assert!(parser.read_int32().is_err()); + } + + #[test] + fn test_unexpected_eof_int64() { + let data = &[0x01, 0x02, 0x03, 0x04]; // Only 4 bytes, need 8 + let mut parser = Parser::new(data); + assert!(parser.read_int64().is_err()); + } + + #[test] + fn test_unexpected_eof_double() { + let data = &[0x01, 0x02, 0x03, 0x04]; // Only 4 bytes, need 8 + let mut parser = Parser::new(data); + assert!(parser.read_double().is_err()); + } + + #[test] + fn test_unexpected_eof_object_id() { + let data = &[0x01, 0x02, 0x03, 0x04]; // Only 4 bytes, need 12 + let mut parser = Parser::new(data); + assert!(parser.read_object_id().is_err()); + } + + #[test] + fn test_string_length_overflow() { + // Invalid negative length + let data = &[0xff, 0xff, 0xff, 0xff, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.read_string().is_err()); + } + + #[test] + fn test_string_insufficient_data() { + // Claims length 10 but only has 5 bytes total + let data = &[0x0a, 0x00, 0x00, 0x00, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.read_string().is_err()); + } + + #[test] + fn test_binary_length_overflow() { + // Invalid negative length + let data = &[0xff, 0xff, 0xff, 0xff, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.read_binary().is_err()); + } + + #[test] + fn test_binary_insufficient_data() { + // Claims length 10 but only has 2 bytes after subtype + let data = &[0x0a, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02]; + let mut parser = Parser::new(data); + assert!(parser.read_binary().is_err()); + } + + #[test] + fn test_cstr_unterminated() { + let data = &[0x48, 0x65, 0x6c, 0x6c, 0x6f]; // "Hello" without null terminator + let mut parser = Parser::new(data); + assert!(parser.read_cstr().is_err()); + } + + #[test] + fn test_invalid_utf8_string() { + // Invalid UTF-8 sequence in string + let data = &[0x05, 0x00, 0x00, 0x00, 0xff, 0xfe, 0xfd, 0xfc, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.read_string().is_err()); + } + + #[test] + fn test_invalid_utf8_cstr() { + // Invalid UTF-8 sequence in cstring + let data = &[0xff, 0xfe, 0xfd, 0xfc, 0x00]; + let mut parser = Parser::new(data); + assert!(parser.read_cstr().is_err()); + } +} + +#[repr(transparent)] +pub struct BinarySubtype(pub u8); + +#[derive(Clone, Copy, Debug)] +pub enum ElementType { + Double = 1, + String = 2, + Document = 3, + Array = 4, + Binary = 5, + Undefined = 6, + ObjectId = 7, + Boolean = 8, + DatetimeUtc = 9, + Null = 10, + Int32 = 16, + Timestamp = 17, + Int64 = 18, +} diff --git a/crates/core/src/checkpoint.rs b/crates/core/src/checkpoint.rs index 1e6527d..7e96272 100644 --- a/crates/core/src/checkpoint.rs +++ b/crates/core/src/checkpoint.rs @@ -1,11 +1,10 @@ extern crate alloc; -use alloc::format; use alloc::string::String; use alloc::vec::Vec; use core::ffi::c_int; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use serde_json as json; use sqlite::ResultCode; use sqlite_nostd as sqlite; @@ -13,9 +12,10 @@ use sqlite_nostd::{Connection, Context, Value}; use crate::create_sqlite_text_fn; use 
crate::error::SQLiteError; -use crate::sync_types::Checkpoint; +use crate::sync::checkpoint::{validate_checkpoint, OwnedBucketChecksum}; +use crate::sync::line::Checkpoint; -#[derive(Serialize, Deserialize)] +#[derive(Serialize)] struct CheckpointResult { valid: bool, failed_buckets: Vec, @@ -26,53 +26,23 @@ fn powersync_validate_checkpoint_impl( args: &[*mut sqlite::value], ) -> Result { let data = args[0].text(); - - let _checkpoint: Checkpoint = serde_json::from_str(data)?; - + let checkpoint: Checkpoint = serde_json::from_str(data)?; let db = ctx.db_handle(); - - // language=SQLite - let statement = db.prepare_v2( - "WITH -bucket_list(bucket, checksum) AS ( - SELECT - json_extract(json_each.value, '$.bucket') as bucket, - json_extract(json_each.value, '$.checksum') as checksum - FROM json_each(json_extract(?1, '$.buckets')) -) -SELECT - bucket_list.bucket as bucket, - IFNULL(buckets.add_checksum, 0) as add_checksum, - IFNULL(buckets.op_checksum, 0) as oplog_checksum, - bucket_list.checksum as expected_checksum -FROM bucket_list - LEFT OUTER JOIN ps_buckets AS buckets ON - buckets.name = bucket_list.bucket -GROUP BY bucket_list.bucket", - )?; - - statement.bind_text(1, data, sqlite::Destructor::STATIC)?; - - let mut failures: Vec = alloc::vec![]; - - while statement.step()? == ResultCode::ROW { - let name = statement.column_text(0)?; - // checksums with column_int are wrapped to i32 by SQLite - let add_checksum = statement.column_int(1); - let oplog_checksum = statement.column_int(2); - let expected_checksum = statement.column_int(3); - - // wrapping add is like +, but safely overflows - let checksum = oplog_checksum.wrapping_add(add_checksum); - - if checksum != expected_checksum { - failures.push(String::from(name)); - } + let buckets: Vec = checkpoint + .buckets + .iter() + .map(OwnedBucketChecksum::from) + .collect(); + + let failures = validate_checkpoint(buckets.iter(), None, db)?; + let mut failed_buckets = Vec::::with_capacity(failures.len()); + for failure in failures { + failed_buckets.push(failure.bucket_name); } let result = CheckpointResult { - valid: failures.is_empty(), - failed_buckets: failures, + valid: failed_buckets.is_empty(), + failed_buckets: failed_buckets, }; Ok(json::to_string(&result)?) 
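The rewritten powersync_validate_checkpoint_impl above no longer walks json_each in SQL; it parses the checkpoint with serde and hands the per-bucket comparison to validate_checkpoint in the new sync module. A minimal sketch of the comparison being delegated, assuming nothing beyond what the removed code shows (the real OwnedBucketChecksum and validate_checkpoint live in crates/core/src/sync, outside this hunk):

// Sketch of the per-bucket check previously done inline: SQLite's column_int
// wraps checksums to i32, and wrapping_add mirrors the server's modular
// checksum arithmetic instead of panicking on overflow in debug builds.
fn bucket_checksum_matches(add_checksum: i32, oplog_checksum: i32, expected_checksum: i32) -> bool {
    oplog_checksum.wrapping_add(add_checksum) == expected_checksum
}

Buckets for which this comparison fails are what the new code copies into failed_buckets from the validate_checkpoint result.
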
diff --git a/crates/core/src/diff.rs b/crates/core/src/diff.rs index 7e5ad18..fd37a05 100644 --- a/crates/core/src/diff.rs +++ b/crates/core/src/diff.rs @@ -1,6 +1,5 @@ extern crate alloc; -use alloc::format; use alloc::string::{String, ToString}; use core::ffi::c_int; diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 3aa5dcf..a2cf8dd 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -1,6 +1,11 @@ -use alloc::string::{String, ToString}; +use alloc::{ + format, + string::{String, ToString}, +}; use core::error::Error; -use sqlite_nostd::{sqlite3, Connection, ResultCode}; +use sqlite_nostd::{context, sqlite3, Connection, Context, ResultCode}; + +use crate::bson::BsonError; #[derive(Debug)] pub struct SQLiteError(pub ResultCode, pub Option); @@ -11,6 +16,24 @@ impl core::fmt::Display for SQLiteError { } } +impl SQLiteError { + pub fn apply_to_ctx(self, description: &str, ctx: *mut context) { + let SQLiteError(code, message) = self; + + if message.is_some() { + ctx.result_error(&format!("{:} {:}", description, message.unwrap())); + } else { + let error = ctx.db_handle().errmsg().unwrap(); + if error == "not an error" { + ctx.result_error(&format!("{:}", description)); + } else { + ctx.result_error(&format!("{:} {:}", description, error)); + } + } + ctx.result_error_code(code); + } +} + impl Error for SQLiteError {} pub trait PSResult { @@ -45,3 +68,15 @@ impl From for SQLiteError { SQLiteError(ResultCode::ABORT, Some(value.to_string())) } } + +impl From for SQLiteError { + fn from(value: core::fmt::Error) -> Self { + SQLiteError(ResultCode::INTERNAL, Some(format!("{}", value))) + } +} + +impl From for SQLiteError { + fn from(value: BsonError) -> Self { + SQLiteError(ResultCode::ERROR, Some(value.to_string())) + } +} diff --git a/crates/core/src/fix035.rs b/crates/core/src/fix035.rs deleted file mode 100644 index f90cb6c..0000000 --- a/crates/core/src/fix035.rs +++ /dev/null @@ -1,47 +0,0 @@ -use alloc::format; - -use crate::error::{PSResult, SQLiteError}; -use sqlite_nostd as sqlite; -use sqlite_nostd::{Connection, ResultCode}; - -use crate::ext::SafeManagedStmt; -use crate::util::quote_identifier; - -// Apply a data migration to fix any existing data affected by the issue -// fixed in v0.3.5. -// -// The issue was that the `ps_updated_rows` table was not being populated -// with remove operations in some cases. This causes the rows to be removed -// from ps_oplog, but not from the ps_data__tables, resulting in dangling rows. -// -// The fix here is to find these dangling rows, and add them to ps_updated_rows. -// The next time the sync_local operation is run, these rows will be removed. -pub fn apply_v035_fix(db: *mut sqlite::sqlite3) -> Result { - // language=SQLite - let statement = db - .prepare_v2("SELECT name, powersync_external_table_name(name) FROM sqlite_master WHERE type='table' AND name GLOB 'ps_data__*'") - .into_db_result(db)?; - - while statement.step()? 
== ResultCode::ROW { - let full_name = statement.column_text(0)?; - let short_name = statement.column_text(1)?; - let quoted = quote_identifier(full_name); - - // language=SQLite - let statement = db.prepare_v2(&format!( - " -INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) -SELECT ?1, id FROM {} - WHERE NOT EXISTS ( - SELECT 1 FROM ps_oplog - WHERE row_type = ?1 AND row_id = {}.id - );", - quoted, quoted - ))?; - statement.bind_text(1, short_name, sqlite::Destructor::STATIC)?; - - statement.exec()?; - } - - Ok(1) -} diff --git a/crates/core/src/fix_data.rs b/crates/core/src/fix_data.rs new file mode 100644 index 0000000..8dcab1b --- /dev/null +++ b/crates/core/src/fix_data.rs @@ -0,0 +1,200 @@ +use core::ffi::c_int; + +use alloc::format; +use alloc::string::String; + +use crate::create_sqlite_optional_text_fn; +use crate::error::{PSResult, SQLiteError}; +use sqlite_nostd::{self as sqlite, ColumnType, Value}; +use sqlite_nostd::{Connection, Context, ResultCode}; + +use crate::ext::SafeManagedStmt; +use crate::util::quote_identifier; + +// Apply a data migration to fix any existing data affected by the issue +// fixed in v0.3.5. +// +// The issue was that the `ps_updated_rows` table was not being populated +// with remove operations in some cases. This causes the rows to be removed +// from ps_oplog, but not from the ps_data__tables, resulting in dangling rows. +// +// The fix here is to find these dangling rows, and add them to ps_updated_rows. +// The next time the sync_local operation is run, these rows will be removed. +pub fn apply_v035_fix(db: *mut sqlite::sqlite3) -> Result { + // language=SQLite + let statement = db + .prepare_v2("SELECT name, powersync_external_table_name(name) FROM sqlite_master WHERE type='table' AND name GLOB 'ps_data__*'") + .into_db_result(db)?; + + while statement.step()? == ResultCode::ROW { + let full_name = statement.column_text(0)?; + let short_name = statement.column_text(1)?; + let quoted = quote_identifier(full_name); + + // language=SQLite + let statement = db.prepare_v2(&format!( + " +INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) +SELECT ?1, id FROM {} + WHERE NOT EXISTS ( + SELECT 1 FROM ps_oplog + WHERE row_type = ?1 AND row_id = {}.id + );", + quoted, quoted + ))?; + statement.bind_text(1, short_name, sqlite::Destructor::STATIC)?; + + statement.exec()?; + } + + Ok(1) +} + +/// Older versions of the JavaScript SDK for PowerSync used to encode the subkey in oplog data +/// entries as JSON. +/// +/// It wasn't supposed to do that, since the keys are regular strings already. To make databases +/// created with those SDKs compatible with other SDKs or the sync client implemented in the core +/// extensions, a migration is necessary. Since this migration is only relevant for the JS SDK, it +/// is mostly implemented there. However, the helper function to remove the key encoding is +/// implemented here because user-defined functions are expensive on JavaScript. +fn remove_duplicate_key_encoding(key: &str) -> Option { + // Acceptable format: // + // Inacceptable format: //"" + // This is a bit of a tricky conversion because both type and id can contain slashes and quotes. + // However, the subkey is either a UUID value or a `/UUID` value - so we know it can't + // end in a quote unless the improper encoding was used. + if !key.ends_with('"') { + return None; + } + + // Since the subkey is JSON-encoded, find the start quote by going backwards. 
+ let mut chars = key.char_indices(); + chars.next_back()?; // Skip the quote ending the string + + enum FoundStartingQuote { + HasQuote { index: usize }, + HasBackslachThenQuote { quote_index: usize }, + } + let mut state: Option = None; + let found_starting_quote = loop { + if let Some((i, char)) = chars.next_back() { + state = match state { + Some(FoundStartingQuote::HasQuote { index }) => { + if char == '\\' { + // We've seen a \" pattern, not the start of the string + Some(FoundStartingQuote::HasBackslachThenQuote { quote_index: index }) + } else { + break Some(index); + } + } + Some(FoundStartingQuote::HasBackslachThenQuote { quote_index }) => { + if char == '\\' { + // \\" pattern, the quote is unescaped + break Some(quote_index); + } else { + None + } + } + None => { + if char == '"' { + Some(FoundStartingQuote::HasQuote { index: i }) + } else { + None + } + } + } + } else { + break None; + } + }?; + + let before_json = &key[..found_starting_quote]; + let mut result: String = serde_json::from_str(&key[found_starting_quote..]).ok()?; + + result.insert_str(0, before_json); + Some(result) +} + +fn powersync_remove_duplicate_key_encoding_impl( + _ctx: *mut sqlite::context, + args: &[*mut sqlite::value], +) -> Result, SQLiteError> { + let arg = args.get(0).ok_or(ResultCode::MISUSE)?; + + if arg.value_type() != ColumnType::Text { + return Err(ResultCode::MISMATCH.into()); + } + + return Ok(remove_duplicate_key_encoding(arg.text())); +} + +create_sqlite_optional_text_fn!( + powersync_remove_duplicate_key_encoding, + powersync_remove_duplicate_key_encoding_impl, + "powersync_remove_duplicate_key_encoding" +); + +pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { + db.create_function_v2( + "powersync_remove_duplicate_key_encoding", + 1, + sqlite::UTF8 | sqlite::DETERMINISTIC, + None, + Some(powersync_remove_duplicate_key_encoding), + None, + None, + None, + )?; + Ok(()) +} + +#[cfg(test)] +mod test { + use core::assert_matches::assert_matches; + + use super::remove_duplicate_key_encoding; + + fn assert_unaffected(source: &str) { + assert_matches!(remove_duplicate_key_encoding(source), None); + } + + #[test] + fn does_not_change_unaffected_keys() { + assert_unaffected("object_type/object_id/subkey"); + assert_unaffected("object_type/object_id/null"); + + // Object type and ID could technically contain quotes and forward slashes + assert_unaffected(r#""object"/"type"/subkey"#); + assert_unaffected("object\"/type/object\"/id/subkey"); + + // Invalid key, but we shouldn't crash + assert_unaffected("\"key\""); + } + + #[test] + fn removes_quotes() { + assert_eq!( + remove_duplicate_key_encoding("foo/bar/\"baz\"").unwrap(), + "foo/bar/baz", + ); + + assert_eq!( + remove_duplicate_key_encoding(r#"foo/bar/"nested/subkey""#).unwrap(), + "foo/bar/nested/subkey" + ); + + assert_eq!( + remove_duplicate_key_encoding(r#"foo/bar/"escaped\"key""#).unwrap(), + "foo/bar/escaped\"key" + ); + assert_eq!( + remove_duplicate_key_encoding(r#"foo/bar/"escaped\\key""#).unwrap(), + "foo/bar/escaped\\key" + ); + assert_eq!( + remove_duplicate_key_encoding(r#"foo/bar/"/\\"subkey""#).unwrap(), + "foo/bar/\"/\\\\subkey" + ); + } +} diff --git a/crates/core/src/json_merge.rs b/crates/core/src/json_merge.rs index 80c1687..cb31479 100644 --- a/crates/core/src/json_merge.rs +++ b/crates/core/src/json_merge.rs @@ -1,6 +1,5 @@ extern crate alloc; -use alloc::format; use alloc::string::{String, ToString}; use core::ffi::c_int; diff --git a/crates/core/src/kv.rs b/crates/core/src/kv.rs index 811409d..3b84791 
100644 --- a/crates/core/src/kv.rs +++ b/crates/core/src/kv.rs @@ -8,10 +8,10 @@ use sqlite::ResultCode; use sqlite_nostd as sqlite; use sqlite_nostd::{Connection, Context}; -use crate::bucket_priority::BucketPriority; use crate::create_sqlite_optional_text_fn; use crate::create_sqlite_text_fn; use crate::error::SQLiteError; +use crate::sync::BucketPriority; fn powersync_client_id_impl( ctx: *mut sqlite::context, @@ -19,17 +19,21 @@ fn powersync_client_id_impl( ) -> Result { let db = ctx.db_handle(); + client_id(db) +} + +pub fn client_id(db: *mut sqlite::sqlite3) -> Result { // language=SQLite let statement = db.prepare_v2("select value from ps_kv where key = 'client_id'")?; if statement.step()? == ResultCode::ROW { let client_id = statement.column_text(0)?; - return Ok(client_id.to_string()); + Ok(client_id.to_string()) } else { - return Err(SQLiteError( + Err(SQLiteError( ResultCode::ABORT, Some(format!("No client_id found in ps_kv")), - )); + )) } } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index b191fd0..76edd45 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -2,8 +2,8 @@ #![feature(vec_into_raw_parts)] #![allow(internal_features)] #![feature(core_intrinsics)] -#![feature(error_in_core)] #![feature(assert_matches)] +#![feature(strict_overflow_ops)] extern crate alloc; @@ -12,13 +12,13 @@ use core::ffi::{c_char, c_int}; use sqlite::ResultCode; use sqlite_nostd as sqlite; -mod bucket_priority; +mod bson; mod checkpoint; mod crud_vtab; mod diff; mod error; mod ext; -mod fix035; +mod fix_data; mod json_merge; mod kv; mod macros; @@ -26,8 +26,8 @@ mod migrations; mod operations; mod operations_vtab; mod schema; +mod sync; mod sync_local; -mod sync_types; mod util; mod uuid; mod version; @@ -57,10 +57,12 @@ fn init_extension(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { crate::views::register(db)?; crate::uuid::register(db)?; crate::diff::register(db)?; + crate::fix_data::register(db)?; crate::json_merge::register(db)?; crate::view_admin::register(db)?; crate::checkpoint::register(db)?; crate::kv::register(db)?; + sync::register(db)?; crate::schema::register(db)?; crate::operations_vtab::register(db)?; diff --git a/crates/core/src/macros.rs b/crates/core/src/macros.rs index aeb1f1a..459a5ab 100644 --- a/crates/core/src/macros.rs +++ b/crates/core/src/macros.rs @@ -11,18 +11,7 @@ macro_rules! create_sqlite_text_fn { let result = $fn_impl_name(ctx, args); if let Err(err) = result { - let SQLiteError(code, message) = SQLiteError::from(err); - if message.is_some() { - ctx.result_error(&format!("{:} {:}", $description, message.unwrap())); - } else { - let error = ctx.db_handle().errmsg().unwrap(); - if error == "not an error" { - ctx.result_error(&format!("{:}", $description)); - } else { - ctx.result_error(&format!("{:} {:}", $description, error)); - } - } - ctx.result_error_code(code); + SQLiteError::from(err).apply_to_ctx($description, ctx); } else if let Ok(r) = result { ctx.result_text_transient(&r); } @@ -43,18 +32,7 @@ macro_rules! 
create_sqlite_optional_text_fn { let result = $fn_impl_name(ctx, args); if let Err(err) = result { - let SQLiteError(code, message) = SQLiteError::from(err); - if message.is_some() { - ctx.result_error(&format!("{:} {:}", $description, message.unwrap())); - } else { - let error = ctx.db_handle().errmsg().unwrap(); - if error == "not an error" { - ctx.result_error(&format!("{:}", $description)); - } else { - ctx.result_error(&format!("{:} {:}", $description, error)); - } - } - ctx.result_error_code(code); + SQLiteError::from(err).apply_to_ctx($description, ctx); } else if let Ok(r) = result { if let Some(s) = r { ctx.result_text_transient(&s); diff --git a/crates/core/src/migrations.rs b/crates/core/src/migrations.rs index 7241957..8fd5357 100644 --- a/crates/core/src/migrations.rs +++ b/crates/core/src/migrations.rs @@ -8,9 +8,9 @@ use sqlite::ResultCode; use sqlite_nostd as sqlite; use sqlite_nostd::{Connection, Context}; -use crate::bucket_priority::BucketPriority; use crate::error::{PSResult, SQLiteError}; -use crate::fix035::apply_v035_fix; +use crate::fix_data::apply_v035_fix; +use crate::sync::BucketPriority; pub const LATEST_VERSION: i32 = 9; diff --git a/crates/core/src/operations.rs b/crates/core/src/operations.rs index c0b7622..0c031b9 100644 --- a/crates/core/src/operations.rs +++ b/crates/core/src/operations.rs @@ -1,7 +1,9 @@ -use alloc::format; -use alloc::string::String; - -use crate::error::{PSResult, SQLiteError}; +use crate::error::SQLiteError; +use crate::sync::line::DataLine; +use crate::sync::operations::insert_bucket_operations; +use crate::sync::storage_adapter::StorageAdapter; +use alloc::vec::Vec; +use serde::Deserialize; use sqlite_nostd as sqlite; use sqlite_nostd::{Connection, ResultCode}; @@ -9,246 +11,17 @@ use crate::ext::SafeManagedStmt; // Run inside a transaction pub fn insert_operation(db: *mut sqlite::sqlite3, data: &str) -> Result<(), SQLiteError> { - // language=SQLite - let statement = db.prepare_v2( - "\ -SELECT - json_extract(e.value, '$.bucket') as bucket, - json_extract(e.value, '$.data') as data, - json_extract(e.value, '$.has_more') as has_more, - json_extract(e.value, '$.after') as after, - json_extract(e.value, '$.next_after') as next_after -FROM json_each(json_extract(?1, '$.buckets')) e", - )?; - statement.bind_text(1, data, sqlite::Destructor::STATIC)?; - - while statement.step()? == ResultCode::ROW { - let bucket = statement.column_text(0)?; - let data = statement.column_text(1)?; - // let _has_more = statement.column_int(2)? != 0; - // let _after = statement.column_text(3)?; - // let _next_after = statement.column_text(4)?; - - insert_bucket_operations(db, bucket, data)?; - } - - Ok(()) -} - -pub fn insert_bucket_operations( - db: *mut sqlite::sqlite3, - bucket: &str, - data: &str, -) -> Result<(), SQLiteError> { - // Statement to insert new operations (only for PUT and REMOVE). - // language=SQLite - let iterate_statement = db.prepare_v2( - "\ -SELECT - json_extract(e.value, '$.op_id') as op_id, - json_extract(e.value, '$.op') as op, - json_extract(e.value, '$.object_type') as object_type, - json_extract(e.value, '$.object_id') as object_id, - json_extract(e.value, '$.checksum') as checksum, - json_extract(e.value, '$.data') as data, - json_extract(e.value, '$.subkey') as subkey -FROM json_each(?) e", - )?; - iterate_statement.bind_text(1, data, sqlite::Destructor::STATIC)?; - - // We do an ON CONFLICT UPDATE simply so that the RETURNING bit works for existing rows. 
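The comment above explains that the upsert form is used so that RETURNING also yields a row for buckets that already exist. Below is a standalone check of that SQLite behaviour using rusqlite and a simplified stand-in for the ps_buckets schema (illustration only; the extension uses its own no_std bindings, and the statement needs SQLite 3.35 or later for RETURNING and the target-less ON CONFLICT).

```rust
// Illustration only: `INSERT ... ON CONFLICT DO UPDATE ... RETURNING` returns
// the existing row on conflict, so a single statement covers both "create
// bucket" and "look up bucket".
use rusqlite::Connection;

fn lookup_bucket(conn: &Connection, name: &str) -> rusqlite::Result<(i64, i64)> {
    conn.query_row(
        "INSERT INTO ps_buckets(name) VALUES (?1)
         ON CONFLICT DO UPDATE SET last_applied_op = last_applied_op
         RETURNING id, last_applied_op",
        [name],
        |row| Ok((row.get(0)?, row.get(1)?)),
    )
}

fn main() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    // Simplified stand-in schema, not the real ps_buckets definition.
    conn.execute_batch(
        "CREATE TABLE ps_buckets (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL UNIQUE,
            last_applied_op INTEGER NOT NULL DEFAULT 0
         );",
    )?;

    let (first_id, _) = lookup_bucket(&conn, "a")?;
    let (second_id, _) = lookup_bucket(&conn, "a")?;
    // The second call hits the conflict path and still returns the same row.
    assert_eq!(first_id, second_id);
    Ok(())
}
```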
- // We can consider splitting this into separate SELECT and INSERT statements. - // language=SQLite - let bucket_statement = db.prepare_v2( - "INSERT INTO ps_buckets(name) - VALUES(?) - ON CONFLICT DO UPDATE - SET last_applied_op = last_applied_op - RETURNING id, last_applied_op", - )?; - bucket_statement.bind_text(1, bucket, sqlite::Destructor::STATIC)?; - bucket_statement.step()?; - - let bucket_id = bucket_statement.column_int64(0); - - // This is an optimization for initial sync - we can avoid persisting individual REMOVE - // operations when last_applied_op = 0. - // We do still need to do the "supersede_statement" step for this case, since a REMOVE - // operation can supersede another PUT operation we're syncing at the same time. - let mut is_empty = bucket_statement.column_int64(1) == 0; - - // Statement to supersede (replace) operations with the same key. - // language=SQLite - let supersede_statement = db.prepare_v2( - "\ -DELETE FROM ps_oplog - WHERE unlikely(ps_oplog.bucket = ?1) - AND ps_oplog.key = ?2 -RETURNING op_id, hash", - )?; - supersede_statement.bind_int64(1, bucket_id)?; - - // language=SQLite - let insert_statement = db.prepare_v2("\ -INSERT INTO ps_oplog(bucket, op_id, key, row_type, row_id, data, hash) VALUES (?, ?, ?, ?, ?, ?, ?)")?; - insert_statement.bind_int64(1, bucket_id)?; - - let updated_row_statement = db.prepare_v2( - "\ -INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) VALUES(?1, ?2)", - )?; - - bucket_statement.reset()?; - - let mut last_op: Option = None; - let mut add_checksum: i32 = 0; - let mut op_checksum: i32 = 0; - let mut added_ops: i32 = 0; - - while iterate_statement.step()? == ResultCode::ROW { - let op_id = iterate_statement.column_int64(0); - let op = iterate_statement.column_text(1)?; - let object_type = iterate_statement.column_text(2); - let object_id = iterate_statement.column_text(3); - let checksum = iterate_statement.column_int(4); - let op_data = iterate_statement.column_text(5); - - last_op = Some(op_id); - added_ops += 1; - - if op == "PUT" || op == "REMOVE" { - let key: String; - if let (Ok(object_type), Ok(object_id)) = (object_type.as_ref(), object_id.as_ref()) { - let subkey = iterate_statement.column_text(6).unwrap_or("null"); - key = format!("{}/{}/{}", &object_type, &object_id, subkey); - } else { - key = String::from(""); - } - - supersede_statement.bind_text(2, &key, sqlite::Destructor::STATIC)?; - - let mut superseded = false; - - while supersede_statement.step()? == ResultCode::ROW { - // Superseded (deleted) a previous operation, add the checksum - let supersede_checksum = supersede_statement.column_int(1); - add_checksum = add_checksum.wrapping_add(supersede_checksum); - op_checksum = op_checksum.wrapping_sub(supersede_checksum); - - // Superseded an operation, only skip if the bucket was empty - // Previously this checked "superseded_op <= last_applied_op". - // However, that would not account for a case where a previous - // PUT operation superseded the original PUT operation in this - // same batch, in which case superseded_op is not accurate for this. 
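A note on the checksum bookkeeping in the supersede step above: the hash of the deleted operation is added to add_checksum and subtracted from op_checksum, so the combined bucket checksum that checkpoints are validated against is left unchanged by superseding. A minimal standalone check of that invariant using wrapping 32-bit arithmetic (stand-in values, not the extension's Checksum type):

```rust
use std::num::Wrapping;

fn main() {
    // Stand-in bucket counters (the real code keeps these per bucket row).
    let mut add_checksum = Wrapping(0x1111_2222u32);
    let mut op_checksum = Wrapping(0x9999_0000u32);
    let total_before = add_checksum + op_checksum;

    // Superseding an operation moves its hash from op_checksum to add_checksum...
    let superseded = Wrapping(0xdead_beefu32);
    add_checksum += superseded;
    op_checksum -= superseded;

    // ...so the combined checksum stays the same.
    assert_eq!(add_checksum + op_checksum, total_before);
}
```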
- if !is_empty { - superseded = true; - } - } - supersede_statement.reset()?; - - if op == "REMOVE" { - let should_skip_remove = !superseded; - - add_checksum = add_checksum.wrapping_add(checksum); - - if !should_skip_remove { - if let (Ok(object_type), Ok(object_id)) = (object_type, object_id) { - updated_row_statement.bind_text( - 1, - object_type, - sqlite::Destructor::STATIC, - )?; - updated_row_statement.bind_text( - 2, - object_id, - sqlite::Destructor::STATIC, - )?; - updated_row_statement.exec()?; - } - } - - continue; - } - - insert_statement.bind_int64(2, op_id)?; - if key != "" { - insert_statement.bind_text(3, &key, sqlite::Destructor::STATIC)?; - } else { - insert_statement.bind_null(3)?; - } - - if let (Ok(object_type), Ok(object_id)) = (object_type, object_id) { - insert_statement.bind_text(4, object_type, sqlite::Destructor::STATIC)?; - insert_statement.bind_text(5, object_id, sqlite::Destructor::STATIC)?; - } else { - insert_statement.bind_null(4)?; - insert_statement.bind_null(5)?; - } - if let Ok(data) = op_data { - insert_statement.bind_text(6, data, sqlite::Destructor::STATIC)?; - } else { - insert_statement.bind_null(6)?; - } - - insert_statement.bind_int(7, checksum)?; - insert_statement.exec()?; - - op_checksum = op_checksum.wrapping_add(checksum); - } else if op == "MOVE" { - add_checksum = add_checksum.wrapping_add(checksum); - } else if op == "CLEAR" { - // Any remaining PUT operations should get an implicit REMOVE - // language=SQLite - let clear_statement1 = db - .prepare_v2( - "INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) -SELECT row_type, row_id -FROM ps_oplog -WHERE bucket = ?1", - ) - .into_db_result(db)?; - clear_statement1.bind_int64(1, bucket_id)?; - clear_statement1.exec()?; - - let clear_statement2 = db - .prepare_v2("DELETE FROM ps_oplog WHERE bucket = ?1") - .into_db_result(db)?; - clear_statement2.bind_int64(1, bucket_id)?; - clear_statement2.exec()?; - - // And we need to re-apply all of those. - // We also replace the checksum with the checksum of the CLEAR op. 
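All of the json_each/json_extract plumbing removed above is replaced, in the new insert_operation just below, by deserializing the payload with serde. A standalone sketch of that parsing style with simplified stand-in types (the real code uses the DataLine and OplogEntry definitions added later in this patch):

```rust
// Illustration only: parsing a `{"buckets": [...]}` payload with serde instead
// of iterating it with json_each in SQL. Field names follow the sync protocol
// as shown in this patch; the types are simplified stand-ins.
use serde::Deserialize;

#[derive(Deserialize)]
struct BucketBatch<'a> {
    #[serde(borrow)]
    buckets: Vec<DataLine<'a>>,
}

#[derive(Deserialize)]
struct DataLine<'a> {
    bucket: &'a str,
    data: Vec<OplogEntry<'a>>,
}

#[derive(Deserialize)]
struct OplogEntry<'a> {
    op: &'a str,
    op_id: &'a str,
    checksum: u32,
}

fn main() {
    let payload = r#"{"buckets": [
        {"bucket": "a", "data": [{"op": "PUT", "op_id": "1", "checksum": 10}]}
    ]}"#;

    let batch: BucketBatch = serde_json::from_str(payload).expect("valid payload");
    assert_eq!(batch.buckets.len(), 1);
    assert_eq!(batch.buckets[0].bucket, "a");
    assert_eq!(batch.buckets[0].data[0].op, "PUT");
}
```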
- // language=SQLite - let clear_statement2 = db.prepare_v2( - "UPDATE ps_buckets SET last_applied_op = 0, add_checksum = ?1, op_checksum = 0 WHERE id = ?2", - )?; - clear_statement2.bind_int64(2, bucket_id)?; - clear_statement2.bind_int(1, checksum)?; - clear_statement2.exec()?; - - add_checksum = 0; - is_empty = true; - op_checksum = 0; - } + #[derive(Deserialize)] + struct BucketBatch<'a> { + #[serde(borrow)] + buckets: Vec>, } - if let Some(last_op) = &last_op { - // language=SQLite - let statement = db.prepare_v2( - "UPDATE ps_buckets - SET last_op = ?2, - add_checksum = (add_checksum + ?3) & 0xffffffff, - op_checksum = (op_checksum + ?4) & 0xffffffff, - count_since_last = count_since_last + ?5 - WHERE id = ?1", - )?; - statement.bind_int64(1, bucket_id)?; - statement.bind_int64(2, *last_op)?; - statement.bind_int(3, add_checksum)?; - statement.bind_int(4, op_checksum)?; - statement.bind_int(5, added_ops)?; + let batch: BucketBatch = serde_json::from_str(data)?; + let adapter = StorageAdapter::new(db)?; - statement.exec()?; + for line in &batch.buckets { + insert_bucket_operations(&adapter, &line)?; } Ok(()) diff --git a/crates/core/src/schema/management.rs b/crates/core/src/schema/management.rs index 6f66d03..bccba3f 100644 --- a/crates/core/src/schema/management.rs +++ b/crates/core/src/schema/management.rs @@ -1,8 +1,8 @@ extern crate alloc; -use alloc::format; use alloc::string::String; use alloc::vec::Vec; +use alloc::{format, vec}; use core::ffi::c_int; use sqlite::{Connection, ResultCode, Value}; @@ -14,6 +14,8 @@ use crate::ext::ExtendedDatabase; use crate::util::{quote_identifier, quote_json_path}; use crate::{create_auto_tx_function, create_sqlite_text_fn}; +use super::Schema; + fn update_tables(db: *mut sqlite::sqlite3, schema: &str) -> Result<(), SQLiteError> { { // In a block so that the statement is finalized before dropping tables @@ -138,87 +140,83 @@ SELECT name, internal_name, local_only FROM powersync_tables WHERE name NOT IN ( fn update_indexes(db: *mut sqlite::sqlite3, schema: &str) -> Result<(), SQLiteError> { let mut statements: Vec = alloc::vec![]; + let schema = serde_json::from_str::(schema)?; + let mut expected_index_names: Vec = vec![]; { // In a block so that the statement is finalized before dropping indexes // language=SQLite - let statement = db.prepare_v2("\ -SELECT - powersync_internal_table_name(tables.value) as table_name, - (powersync_internal_table_name(tables.value) || '__' || json_extract(indexes.value, '$.name')) as index_name, - json_extract(indexes.value, '$.columns') as index_columns, - ifnull(sqlite_master.sql, '') as sql - FROM json_each(json_extract(?, '$.tables')) tables - CROSS JOIN json_each(json_extract(tables.value, '$.indexes')) indexes - LEFT JOIN sqlite_master ON sqlite_master.name = index_name AND sqlite_master.type = 'index' - ").into_db_result(db)?; - statement.bind_text(1, schema, sqlite::Destructor::STATIC)?; + let find_index = + db.prepare_v2("SELECT sql FROM sqlite_master WHERE name = ? AND type = 'index'")?; - while statement.step().into_db_result(db)? == ResultCode::ROW { - let table_name = statement.column_text(0)?; - let index_name = statement.column_text(1)?; - let columns = statement.column_text(2)?; - let existing_sql = statement.column_text(3)?; - - // language=SQLite - let stmt2 = db.prepare_v2("select json_extract(e.value, '$.name') as name, json_extract(e.value, '$.type') as type, json_extract(e.value, '$.ascending') as ascending from json_each(?) 
e")?; - stmt2.bind_text(1, columns, sqlite::Destructor::STATIC)?; - - let mut column_values: Vec = alloc::vec![]; - while stmt2.step()? == ResultCode::ROW { - let name = stmt2.column_text(0)?; - let type_name = stmt2.column_text(1)?; - let ascending = stmt2.column_int(2) != 0; - - if ascending { - let value = format!( + for table in &schema.tables { + let table_name = table.internal_name(); + + for index in &table.indexes { + let index_name = format!("{}__{}", table_name, &index.name); + + let existing_sql = { + find_index.reset()?; + find_index.bind_text(1, &index_name, sqlite::Destructor::STATIC)?; + + let result = if let ResultCode::ROW = find_index.step()? { + Some(find_index.column_text(0)?) + } else { + None + }; + + result + }; + + let mut column_values: Vec = alloc::vec![]; + for indexed_column in &index.columns { + let mut value = format!( "CAST(json_extract(data, {:}) as {:})", - quote_json_path(name), - type_name - ); - column_values.push(value); - } else { - let value = format!( - "CAST(json_extract(data, {:}) as {:}) DESC", - quote_json_path(name), - type_name + quote_json_path(&indexed_column.name), + &indexed_column.type_name ); + + if !indexed_column.ascending { + value += " DESC"; + } + column_values.push(value); } - } - let sql = format!( - "CREATE INDEX {} ON {}({})", - quote_identifier(index_name), - quote_identifier(table_name), - column_values.join(", ") - ); - if existing_sql == "" { - statements.push(sql); - } else if existing_sql != sql { - statements.push(format!("DROP INDEX {}", quote_identifier(index_name))); - statements.push(sql); + let sql = format!( + "CREATE INDEX {} ON {}({})", + quote_identifier(&index_name), + quote_identifier(&table_name), + column_values.join(", ") + ); + + if existing_sql.is_none() { + statements.push(sql); + } else if existing_sql != Some(&sql) { + statements.push(format!("DROP INDEX {}", quote_identifier(&index_name))); + statements.push(sql); + } + + expected_index_names.push(index_name); } } // In a block so that the statement is finalized before dropping indexes // language=SQLite - let statement = db.prepare_v2("\ -WITH schema_indexes AS ( -SELECT - powersync_internal_table_name(tables.value) as table_name, - (powersync_internal_table_name(tables.value) || '__' || json_extract(indexes.value, '$.name')) as index_name - FROM json_each(json_extract(?1, '$.tables')) tables - CROSS JOIN json_each(json_extract(tables.value, '$.indexes')) indexes -) + let statement = db + .prepare_v2( + "\ SELECT sqlite_master.name as index_name FROM sqlite_master WHERE sqlite_master.type = 'index' AND sqlite_master.name GLOB 'ps_data_*' - AND sqlite_master.name NOT IN (SELECT index_name FROM schema_indexes) -").into_db_result(db)?; - statement.bind_text(1, schema, sqlite::Destructor::STATIC)?; + AND sqlite_master.name NOT IN (SELECT value FROM json_each(?)) +", + ) + .into_db_result(db)?; + let json_names = serde_json::to_string(&expected_index_names)?; + statement.bind_text(1, &json_names, sqlite::Destructor::STATIC)?; while statement.step()? 
== ResultCode::ROW { let name = statement.column_text(0)?; diff --git a/crates/core/src/schema/mod.rs b/crates/core/src/schema/mod.rs index a0a277e..96fb732 100644 --- a/crates/core/src/schema/mod.rs +++ b/crates/core/src/schema/mod.rs @@ -1,11 +1,16 @@ mod management; mod table_info; +use alloc::vec::Vec; +use serde::Deserialize; use sqlite::ResultCode; use sqlite_nostd as sqlite; -pub use table_info::{ - ColumnInfo, ColumnNameAndTypeStatement, DiffIncludeOld, TableInfo, TableInfoFlags, -}; +pub use table_info::{DiffIncludeOld, Table, TableInfoFlags}; + +#[derive(Deserialize)] +pub struct Schema { + tables: Vec, +} pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { management::register(db) diff --git a/crates/core/src/schema/table_info.rs b/crates/core/src/schema/table_info.rs index 0bfbfa5..4224221 100644 --- a/crates/core/src/schema/table_info.rs +++ b/crates/core/src/schema/table_info.rs @@ -1,103 +1,127 @@ -use core::marker::PhantomData; +use alloc::{format, string::String, vec, vec::Vec}; +use serde::{de::Visitor, Deserialize}; -use alloc::{ - string::{String, ToString}, - vec::Vec, -}; -use streaming_iterator::StreamingIterator; - -use crate::error::SQLiteError; -use sqlite::{Connection, ResultCode}; -use sqlite_nostd::{self as sqlite, ManagedStmt}; - -pub struct TableInfo { +#[derive(Deserialize)] +pub struct Table { pub name: String, - pub view_name: String, + #[serde(rename = "view_name")] + pub view_name_override: Option, + pub columns: Vec, + #[serde(default)] + pub indexes: Vec, + #[serde( + default, + rename = "include_old", + deserialize_with = "deserialize_include_old" + )] pub diff_include_old: Option, + #[serde(flatten)] pub flags: TableInfoFlags, } -impl TableInfo { - pub fn parse_from(db: *mut sqlite::sqlite3, data: &str) -> Result { - // language=SQLite - let statement = db.prepare_v2( - "SELECT - json_extract(?1, '$.name'), - ifnull(json_extract(?1, '$.view_name'), json_extract(?1, '$.name')), - json_extract(?1, '$.local_only'), - json_extract(?1, '$.insert_only'), - json_extract(?1, '$.include_old'), - json_extract(?1, '$.include_metadata'), - json_extract(?1, '$.include_old_only_when_changed'), - json_extract(?1, '$.ignore_empty_update')", - )?; - statement.bind_text(1, data, sqlite::Destructor::STATIC)?; - - let step_result = statement.step()?; - if step_result != ResultCode::ROW { - return Err(SQLiteError::from(ResultCode::SCHEMA)); - } +impl Table { + pub fn from_json(text: &str) -> Result { + serde_json::from_str(text) + } - let name = statement.column_text(0)?.to_string(); - let view_name = statement.column_text(1)?.to_string(); - let flags = { - let local_only = statement.column_int(2) != 0; - let insert_only = statement.column_int(3) != 0; - let include_metadata = statement.column_int(5) != 0; - let include_old_only_when_changed = statement.column_int(6) != 0; - let ignore_empty_update = statement.column_int(7) != 0; - - let mut flags = TableInfoFlags::default(); - flags = flags.set_flag(TableInfoFlags::LOCAL_ONLY, local_only); - flags = flags.set_flag(TableInfoFlags::INSERT_ONLY, insert_only); - flags = flags.set_flag(TableInfoFlags::INCLUDE_METADATA, include_metadata); - flags = flags.set_flag( - TableInfoFlags::INCLUDE_OLD_ONLY_WHEN_CHANGED, - include_old_only_when_changed, - ); - flags = flags.set_flag(TableInfoFlags::IGNORE_EMPTY_UPDATE, ignore_empty_update); - flags - }; - - let include_old = match statement.column_type(4)? 
{ - sqlite_nostd::ColumnType::Text => { - let columns: Vec = serde_json::from_str(statement.column_text(4)?)?; - Some(DiffIncludeOld::OnlyForColumns { columns }) - } + pub fn view_name(&self) -> &str { + self.view_name_override + .as_deref() + .unwrap_or(self.name.as_str()) + } - sqlite_nostd::ColumnType::Integer => { - if statement.column_int(4) != 0 { - Some(DiffIncludeOld::ForAllColumns) - } else { - None - } - } - _ => None, - }; - - // Don't allow include_metadata for local_only tables, it breaks our trigger setup and makes - // no sense because these changes are never inserted into ps_crud. - if flags.include_metadata() && flags.local_only() { - return Err(SQLiteError( - ResultCode::ERROR, - Some("include_metadata and local_only are incompatible".to_string()), - )); + pub fn internal_name(&self) -> String { + if self.flags.local_only() { + format!("ps_data_local__{:}", self.name) + } else { + format!("ps_data__{:}", self.name) } + } - return Ok(TableInfo { - name, - view_name, - diff_include_old: include_old, - flags, - }); + pub fn column_names(&self) -> impl Iterator { + self.columns.iter().map(|c| c.name.as_str()) } } +#[derive(Deserialize)] +pub struct Column { + pub name: String, + #[serde(rename = "type")] + pub type_name: String, +} + +#[derive(Deserialize)] +pub struct Index { + pub name: String, + pub columns: Vec, +} + +#[derive(Deserialize)] +pub struct IndexedColumn { + pub name: String, + pub ascending: bool, + #[serde(rename = "type")] + pub type_name: String, +} + pub enum DiffIncludeOld { OnlyForColumns { columns: Vec }, ForAllColumns, } +fn deserialize_include_old<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> Result, D::Error> { + struct IncludeOldVisitor; + + impl<'de> Visitor<'de> for IncludeOldVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(formatter, "an array of columns, or true") + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + return Ok(None); + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(if v { + Some(DiffIncludeOld::ForAllColumns) + } else { + None + }) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut elements: Vec = vec![]; + while let Some(next) = seq.next_element::()? 
{ + elements.push(next); + } + + Ok(Some(DiffIncludeOld::OnlyForColumns { columns: elements })) + } + } + + deserializer.deserialize_option(IncludeOldVisitor) +} + #[derive(Clone, Copy)] #[repr(transparent)] pub struct TableInfoFlags(pub u32); @@ -148,53 +172,56 @@ impl Default for TableInfoFlags { } } -pub struct ColumnNameAndTypeStatement<'a> { - pub stmt: ManagedStmt, - table: PhantomData<&'a str>, -} +impl<'de> Deserialize<'de> for TableInfoFlags { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct FlagsVisitor; -impl ColumnNameAndTypeStatement<'_> { - pub fn new(db: *mut sqlite::sqlite3, table: &str) -> Result { - let stmt = db.prepare_v2("select json_extract(e.value, '$.name'), json_extract(e.value, '$.type') from json_each(json_extract(?, '$.columns')) e")?; - stmt.bind_text(1, table, sqlite::Destructor::STATIC)?; + impl<'de> Visitor<'de> for FlagsVisitor { + type Value = TableInfoFlags; - Ok(Self { - stmt, - table: PhantomData, - }) - } + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(formatter, "an object with table flags") + } - fn step(stmt: &ManagedStmt) -> Result, ResultCode> { - if stmt.step()? == ResultCode::ROW { - let name = stmt.column_text(0)?; - let type_name = stmt.column_text(1)?; + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut flags = TableInfoFlags::default(); + + while let Some((key, value)) = map.next_entry::<&'de str, bool>()? { + flags = flags.set_flag( + match key { + "local_only" => TableInfoFlags::LOCAL_ONLY, + "insert_only" => TableInfoFlags::INSERT_ONLY, + "include_metadata" => TableInfoFlags::INCLUDE_METADATA, + "include_old_only_when_changed" => { + TableInfoFlags::INCLUDE_OLD_ONLY_WHEN_CHANGED + } + "ignore_empty_update" => TableInfoFlags::IGNORE_EMPTY_UPDATE, + _ => continue, + }, + value, + ); + } - return Ok(Some(ColumnInfo { name, type_name })); + Ok(flags) + } } - Ok(None) + deserializer.deserialize_struct( + "TableInfoFlags", + &[ + "local_only", + "insert_only", + "include_metadata", + "include_old_only_when_changed", + "ignore_empty_update", + ], + FlagsVisitor, + ) } - - pub fn streaming_iter( - &mut self, - ) -> impl StreamingIterator> { - streaming_iterator::from_fn(|| match Self::step(&self.stmt) { - Err(e) => Some(Err(e)), - Ok(Some(other)) => Some(Ok(other)), - Ok(None) => None, - }) - } - - pub fn names_iter(&mut self) -> impl StreamingIterator> { - self.streaming_iter().map(|item| match item { - Ok(row) => Ok(row.name), - Err(e) => Err(*e), - }) - } -} - -#[derive(Clone)] -pub struct ColumnInfo<'a> { - pub name: &'a str, - pub type_name: &'a str, } diff --git a/crates/core/src/bucket_priority.rs b/crates/core/src/sync/bucket_priority.rs similarity index 79% rename from crates/core/src/bucket_priority.rs rename to crates/core/src/sync/bucket_priority.rs index 454f1fe..a69f2a6 100644 --- a/crates/core/src/bucket_priority.rs +++ b/crates/core/src/sync/bucket_priority.rs @@ -1,10 +1,10 @@ -use serde::{de::Visitor, Deserialize}; +use serde::{de::Visitor, Deserialize, Serialize}; use sqlite_nostd::ResultCode; use crate::error::SQLiteError; #[repr(transparent)] -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct BucketPriority { pub number: i32, } @@ -14,6 +14,8 @@ impl BucketPriority { self == BucketPriority::HIGHEST } + /// The priority to use when the sync service doesn't attach priorities in checkpoints. 
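The hand-written visitor above exists so that include_old may be either a boolean or a list of column names. For reference, the same accepted shapes can be expressed with an untagged enum in plain serde; this standalone sketch is an illustration only, not the extension's code (which also has to run under no_std):

```rust
use serde::Deserialize;

// Illustration only: `include_old` may be `true`/`false` or an array of column
// names; an untagged enum accepts both shapes.
#[derive(Deserialize, Debug, PartialEq)]
#[serde(untagged)]
enum IncludeOld {
    ForAllColumns(bool),
    OnlyForColumns(Vec<String>),
}

#[derive(Deserialize)]
struct TableOptions {
    #[serde(default)]
    include_old: Option<IncludeOld>,
}

fn main() {
    let all: TableOptions = serde_json::from_str(r#"{"include_old": true}"#).unwrap();
    assert_eq!(all.include_old, Some(IncludeOld::ForAllColumns(true)));

    let some: TableOptions =
        serde_json::from_str(r#"{"include_old": ["name", "email"]}"#).unwrap();
    assert_eq!(
        some.include_old,
        Some(IncludeOld::OnlyForColumns(vec!["name".into(), "email".into()]))
    );

    let none: TableOptions = serde_json::from_str(r#"{}"#).unwrap();
    assert_eq!(none.include_old, None);
}
```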
+ pub const FALLBACK: BucketPriority = BucketPriority { number: 3 }; pub const HIGHEST: BucketPriority = BucketPriority { number: 0 }; /// A low priority used to represent fully-completed sync operations across all priorities. @@ -43,7 +45,13 @@ impl Into for BucketPriority { impl PartialOrd for BucketPriority { fn partial_cmp(&self, other: &BucketPriority) -> Option { - Some(self.number.partial_cmp(&other.number)?.reverse()) + Some(self.cmp(other)) + } +} + +impl Ord for BucketPriority { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.number.cmp(&other.number).reverse() } } @@ -87,3 +95,12 @@ impl<'de> Deserialize<'de> for BucketPriority { deserializer.deserialize_i32(PriorityVisitor) } } + +impl Serialize for BucketPriority { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_i32(self.number) + } +} diff --git a/crates/core/src/sync/checkpoint.rs b/crates/core/src/sync/checkpoint.rs new file mode 100644 index 0000000..57c9b7c --- /dev/null +++ b/crates/core/src/sync/checkpoint.rs @@ -0,0 +1,91 @@ +use alloc::{string::String, vec::Vec}; +use num_traits::Zero; + +use crate::{ + error::SQLiteError, + sync::{line::BucketChecksum, BucketPriority, Checksum}, +}; +use sqlite_nostd::{self as sqlite, Connection, ResultCode}; + +/// A structure cloned from [BucketChecksum]s with an owned bucket name instead of one borrowed from +/// a sync line. +#[derive(Debug, Clone)] +pub struct OwnedBucketChecksum { + pub bucket: String, + pub checksum: Checksum, + pub priority: BucketPriority, + pub count: Option, +} + +impl OwnedBucketChecksum { + pub fn is_in_priority(&self, prio: Option) -> bool { + match prio { + None => true, + Some(prio) => self.priority >= prio, + } + } +} + +impl From<&'_ BucketChecksum<'_>> for OwnedBucketChecksum { + fn from(value: &'_ BucketChecksum<'_>) -> Self { + Self { + bucket: value.bucket.clone().into_owned(), + checksum: value.checksum, + priority: value.priority.unwrap_or(BucketPriority::FALLBACK), + count: value.count, + } + } +} + +pub struct ChecksumMismatch { + pub bucket_name: String, + pub expected_checksum: Checksum, + pub actual_op_checksum: Checksum, + pub actual_add_checksum: Checksum, +} + +pub fn validate_checkpoint<'a>( + buckets: impl Iterator, + priority: Option, + db: *mut sqlite::sqlite3, +) -> Result, SQLiteError> { + // language=SQLite + let statement = db.prepare_v2( + " +SELECT + ps_buckets.add_checksum as add_checksum, + ps_buckets.op_checksum as oplog_checksum +FROM ps_buckets WHERE name = ?;", + )?; + + let mut failures: Vec = Vec::new(); + for bucket in buckets { + if bucket.is_in_priority(priority) { + statement.bind_text(1, &bucket.bucket, sqlite_nostd::Destructor::STATIC)?; + + let (add_checksum, oplog_checksum) = match statement.step()? 
{ + ResultCode::ROW => { + let add_checksum = Checksum::from_i32(statement.column_int(0)); + let oplog_checksum = Checksum::from_i32(statement.column_int(1)); + (add_checksum, oplog_checksum) + } + _ => (Checksum::zero(), Checksum::zero()), + }; + + let actual = add_checksum + oplog_checksum; + + if actual != bucket.checksum { + failures.push(ChecksumMismatch { + bucket_name: bucket.bucket.clone(), + expected_checksum: bucket.checksum, + actual_add_checksum: add_checksum, + actual_op_checksum: oplog_checksum, + }); + } + + statement.reset()?; + } + } + + Ok(failures) +} diff --git a/crates/core/src/sync/checksum.rs b/crates/core/src/sync/checksum.rs new file mode 100644 index 0000000..c6f2bc6 --- /dev/null +++ b/crates/core/src/sync/checksum.rs @@ -0,0 +1,214 @@ +use core::{ + fmt::Display, + num::Wrapping, + ops::{Add, AddAssign, Sub, SubAssign}, +}; + +use num_traits::float::FloatCore; +use num_traits::Zero; +use serde::{de::Visitor, Deserialize, Serialize}; + +/// A checksum as received from the sync service. +/// +/// Conceptually, we use unsigned 32 bit integers to represent checksums, and adding checksums +/// should be a wrapping add. +#[repr(transparent)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] +pub struct Checksum(Wrapping); + +impl Checksum { + pub const fn value(self) -> u32 { + self.0 .0 + } + + pub const fn from_value(value: u32) -> Self { + Self(Wrapping(value)) + } + + pub const fn from_i32(value: i32) -> Self { + Self::from_value(value as u32) + } + + pub const fn bitcast_i32(self) -> i32 { + self.value() as i32 + } +} + +impl Zero for Checksum { + fn zero() -> Self { + const { Self::from_value(0) } + } + + fn is_zero(&self) -> bool { + self.value() == 0 + } +} + +impl Add for Checksum { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } +} + +impl AddAssign for Checksum { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0 + } +} + +impl Sub for Checksum { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0) + } +} + +impl SubAssign for Checksum { + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + } +} + +impl From for Checksum { + fn from(value: u32) -> Self { + Self::from_value(value) + } +} + +impl Display for Checksum { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#010x}", self.value()) + } +} + +impl<'de> Deserialize<'de> for Checksum { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MyVisitor; + + impl<'de> Visitor<'de> for MyVisitor { + type Value = Checksum; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(formatter, "a number to interpret as a checksum") + } + + fn visit_u32(self, v: u32) -> Result + where + E: serde::de::Error, + { + Ok(v.into()) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + let as_u32: u32 = v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Unsigned(v), &"a 32-bit int") + })?; + Ok(as_u32.into()) + } + + fn visit_i32(self, v: i32) -> Result + where + E: serde::de::Error, + { + Ok(Checksum::from_i32(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + // This is supposed to be an u32, but it could also be a i32 that we need to + // normalize. 
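Put differently: values already in the u32 range are taken as-is, while negative values are reinterpreted as the bit-equivalent u32. A standalone sketch of that normalization, using the same values as the checksum tests further down (illustration only, not the visitor itself):

```rust
// Illustration only: normalizing a JSON integer into the u32 checksum space,
// mirroring the idea behind the visitor's visit_i64/visit_f64 methods.
fn normalize_checksum(v: i64) -> Option<u32> {
    if v >= i64::from(u32::MIN) && v <= i64::from(u32::MAX) {
        // Already in range: take the value as-is.
        Some(v as u32)
    } else {
        // Otherwise it must fit in i32, and its bits are reinterpreted as u32.
        i32::try_from(v).ok().map(|v| v as u32)
    }
}

fn main() {
    assert_eq!(normalize_checksum(0), Some(0));
    assert_eq!(normalize_checksum(-1), Some(u32::MAX));
    assert_eq!(normalize_checksum(3573495687), Some(3573495687));
    // The same checksum transmitted as a signed 32-bit value.
    assert_eq!(normalize_checksum(-721471609), Some(3573495687));
    // Out of range for both u32 and i32: rejected.
    assert_eq!(normalize_checksum(1_i64 << 40), None);
}
```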
+ let min: i64 = u32::MIN.into(); + let max: i64 = u32::MAX.into(); + + if v >= min && v <= max { + return Ok(Checksum::from(v as u32)); + } + + let as_i32: i32 = v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Signed(v), &"a 32-bit int") + })?; + Ok(Checksum::from_i32(as_i32)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + if !v.is_finite() || f64::trunc(v) != v { + return Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a whole number", + )); + } + + self.visit_i64(v as i64) + } + } + + deserializer.deserialize_u32(MyVisitor) + } +} + +#[cfg(test)] +mod test { + use num_traits::Zero; + + use super::Checksum; + + #[test] + pub fn test_binary_representation() { + assert_eq!(Checksum::from_i32(-1).value(), u32::MAX); + assert_eq!(Checksum::from(u32::MAX).value(), u32::MAX); + assert_eq!(Checksum::from(u32::MAX).bitcast_i32(), -1); + } + + fn deserialize(from: &str) -> Checksum { + serde_json::from_str(from).expect("should deserialize") + } + + #[test] + pub fn test_deserialize() { + assert_eq!(deserialize("0").value(), 0); + assert_eq!(deserialize("-1").value(), u32::MAX); + assert_eq!(deserialize("-1.0").value(), u32::MAX); + + assert_eq!(deserialize("3573495687").value(), 3573495687); + assert_eq!(deserialize("3573495687.0").value(), 3573495687); + assert_eq!(deserialize("-721471609.0").value(), 3573495687); + } + + #[test] + pub fn test_arithmetic() { + assert_eq!(Checksum::from(3) + Checksum::from(7), Checksum::from(10)); + + // Checksums should always wrap around + assert_eq!( + Checksum::from(0xFFFFFFFF) + Checksum::from(1), + Checksum::zero() + ); + assert_eq!( + Checksum::zero() - Checksum::from(1), + Checksum::from(0xFFFFFFFF) + ); + + let mut cs = Checksum::from(0x8FFFFFFF); + cs += Checksum::from(0x80000000); + assert_eq!(cs, Checksum::from(0x0FFFFFFF)); + + cs -= Checksum::from(0x80000001); + assert_eq!(cs, Checksum::from(0x8FFFFFFE)); + } +} diff --git a/crates/core/src/sync/interface.rs b/crates/core/src/sync/interface.rs new file mode 100644 index 0000000..aca5eb9 --- /dev/null +++ b/crates/core/src/sync/interface.rs @@ -0,0 +1,217 @@ +use core::cell::RefCell; +use core::ffi::{c_int, c_void}; + +use alloc::borrow::Cow; +use alloc::boxed::Box; +use alloc::rc::Rc; +use alloc::string::ToString; +use alloc::{string::String, vec::Vec}; +use serde::{Deserialize, Serialize}; +use sqlite::{ResultCode, Value}; +use sqlite_nostd::{self as sqlite, ColumnType}; +use sqlite_nostd::{Connection, Context}; + +use crate::error::SQLiteError; + +use super::streaming_sync::SyncClient; +use super::sync_status::DownloadSyncStatus; + +/// Payload provided by SDKs when requesting a sync iteration. +#[derive(Default, Deserialize)] +pub struct StartSyncStream { + /// Bucket parameters to include in the request when opening a sync stream. + #[serde(default)] + pub parameters: Option>, +} + +/// A request sent from a client SDK to the [SyncClient] with a `powersync_control` invocation. +pub enum SyncControlRequest<'a> { + /// The client requests to start a sync iteration. + /// + /// Earlier iterations are implicitly dropped when receiving this request. + StartSyncStream(StartSyncStream), + /// The client requests to stop the current sync iteration. + StopSyncStream, + /// The client is forwading a sync event to the core extension. + SyncEvent(SyncEvent<'a>), +} + +pub enum SyncEvent<'a> { + /// A synthetic event forwarded to the [SyncClient] after being started. + Initialize, + /// An event requesting the sync client to shut down. 
+ TearDown, + /// Notifies the sync client that a token has been refreshed. + /// + /// In response, we'll stop the current iteration to begin another one with the new token. + DidRefreshToken, + /// Notifies the sync client that the current CRUD upload (for which the client SDK is + /// responsible) has finished. + /// + /// If pending CRUD entries have previously prevented a sync from completing, this even can be + /// used to try again. + UploadFinished, + /// Forward a text line (JSON) received from the sync service. + TextLine { data: &'a str }, + /// Forward a binary line (BSON) received from the sync service. + BinaryLine { data: &'a [u8] }, +} + +/// An instruction sent by the core extension to the SDK. +#[derive(Serialize)] +pub enum Instruction { + LogLine { + severity: LogSeverity, + line: Cow<'static, str>, + }, + /// Update the download status for the ongoing sync iteration. + UpdateSyncStatus { + status: Rc>, + }, + /// Connect to the sync service using the [StreamingSyncRequest] created by the core extension, + /// and then forward received lines via [SyncEvent::TextLine] and [SyncEvent::BinaryLine]. + EstablishSyncStream { request: StreamingSyncRequest }, + FetchCredentials { + /// Whether the credentials currently used have expired. + /// + /// If false, this is a pre-fetch. + did_expire: bool, + }, + // These are defined like this because deserializers in Kotlin can't support either an + // object or a literal value + /// Close the websocket / HTTP stream to the sync service. + CloseSyncStream {}, + /// Flush the file-system if it's non-durable (only applicable to the Dart SDK). + FlushFileSystem {}, + /// Notify that a sync has been completed, prompting client SDKs to clear earlier errors. + DidCompleteSync {}, +} + +#[derive(Serialize)] +pub enum LogSeverity { + DEBUG, + INFO, + WARNING, +} + +#[derive(Serialize)] +pub struct StreamingSyncRequest { + pub buckets: Vec, + pub include_checksum: bool, + pub raw_data: bool, + pub binary_data: bool, + pub client_id: String, + pub parameters: Option>, +} + +#[derive(Serialize)] +pub struct BucketRequest { + pub name: String, + pub after: String, +} + +/// Wrapper around a [SyncClient]. +/// +/// We allocate one instance of this per database (in [register]) - the [SyncClient] has an initial +/// empty state that doesn't consume any resources. +struct SqlController { + client: SyncClient, +} + +pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { + extern "C" fn control( + ctx: *mut sqlite::context, + argc: c_int, + argv: *mut *mut sqlite::value, + ) -> () { + let result = (|| -> Result<(), SQLiteError> { + debug_assert!(!ctx.db_handle().get_autocommit()); + + let controller = unsafe { ctx.user_data().cast::().as_mut() } + .ok_or_else(|| SQLiteError::from(ResultCode::INTERNAL))?; + + let args = sqlite::args!(argc, argv); + let [op, payload] = args else { + return Err(ResultCode::MISUSE.into()); + }; + + if op.value_type() != ColumnType::Text { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("First argument must be a string".to_string()), + )); + } + + let op = op.text(); + let event = match op { + "start" => SyncControlRequest::StartSyncStream({ + if payload.value_type() == ColumnType::Text { + serde_json::from_str(payload.text())? 
+ } else { + StartSyncStream::default() + } + }), + "stop" => SyncControlRequest::StopSyncStream, + "line_text" => SyncControlRequest::SyncEvent(SyncEvent::TextLine { + data: if payload.value_type() == ColumnType::Text { + payload.text() + } else { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("Second argument must be a string".to_string()), + )); + }, + }), + "line_binary" => SyncControlRequest::SyncEvent(SyncEvent::BinaryLine { + data: if payload.value_type() == ColumnType::Blob { + payload.blob() + } else { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("Second argument must be a byte array".to_string()), + )); + }, + }), + "refreshed_token" => SyncControlRequest::SyncEvent(SyncEvent::DidRefreshToken), + "completed_upload" => SyncControlRequest::SyncEvent(SyncEvent::UploadFinished), + _ => { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("Unknown operation".to_string()), + )) + } + }; + + let instructions = controller.client.push_event(event)?; + let formatted = serde_json::to_string(&instructions)?; + ctx.result_text_transient(&formatted); + + Ok(()) + })(); + + if let Err(e) = result { + e.apply_to_ctx("powersync_control", ctx); + } + } + + unsafe extern "C" fn destroy(ptr: *mut c_void) { + drop(Box::from_raw(ptr.cast::())); + } + + let controller = Box::new(SqlController { + client: SyncClient::new(db), + }); + + db.create_function_v2( + "powersync_control", + 2, + sqlite::UTF8 | sqlite::DIRECTONLY, + Some(Box::into_raw(controller).cast()), + Some(control), + None, + None, + Some(destroy), + )?; + + Ok(()) +} diff --git a/crates/core/src/sync/line.rs b/crates/core/src/sync/line.rs new file mode 100644 index 0000000..ab8c199 --- /dev/null +++ b/crates/core/src/sync/line.rs @@ -0,0 +1,308 @@ +use alloc::borrow::Cow; +use alloc::vec::Vec; +use serde::Deserialize; + +use crate::util::{deserialize_optional_string_to_i64, deserialize_string_to_i64}; + +use super::bucket_priority::BucketPriority; +use super::Checksum; + +/// While we would like to always borrow strings for efficiency, that's not consistently possible. +/// With the JSON decoder, borrowing from input data is only possible when the string contains no +/// escape sequences (otherwise, the string is not a direct view of input data and we need an +/// internal copy). 
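That borrow-when-possible behaviour is ordinary serde_json behaviour and can be checked in isolation: a field marked #[serde(borrow)] stays a Cow::Borrowed view of the input unless the JSON string contains escape sequences. A standalone sketch, independent of the extension:

```rust
use serde::Deserialize;
use std::borrow::Cow;

#[derive(Deserialize)]
struct Line<'a> {
    #[serde(borrow)]
    bucket: Cow<'a, str>,
}

fn main() {
    // No escape sequences: the Cow borrows straight from the input buffer.
    let plain: Line = serde_json::from_str(r#"{"bucket": "users"}"#).unwrap();
    assert!(matches!(&plain.bucket, Cow::Borrowed(_)));

    // An escaped quote forces serde_json to build an owned copy.
    let escaped: Line = serde_json::from_str(r#"{"bucket": "user\"s"}"#).unwrap();
    assert!(matches!(&escaped.bucket, Cow::Owned(_)));
    assert_eq!(escaped.bucket, "user\"s");
}
```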
+type SyncLineStr<'a> = Cow<'a, str>; + +#[derive(Deserialize, Debug)] + +pub enum SyncLine<'a> { + #[serde(rename = "checkpoint", borrow)] + Checkpoint(Checkpoint<'a>), + #[serde(rename = "checkpoint_diff", borrow)] + CheckpointDiff(CheckpointDiff<'a>), + + #[serde(rename = "checkpoint_complete")] + CheckpointComplete(CheckpointComplete), + #[serde(rename = "partial_checkpoint_complete")] + CheckpointPartiallyComplete(CheckpointPartiallyComplete), + + #[serde(rename = "data", borrow)] + Data(DataLine<'a>), + + #[serde(rename = "token_expires_in")] + KeepAlive(TokenExpiresIn), +} + +#[derive(Deserialize, Debug)] +pub struct Checkpoint<'a> { + #[serde(deserialize_with = "deserialize_string_to_i64")] + pub last_op_id: i64, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_string_to_i64")] + pub write_checkpoint: Option, + #[serde(borrow)] + pub buckets: Vec>, +} + +#[derive(Deserialize, Debug)] +pub struct CheckpointDiff<'a> { + #[serde(deserialize_with = "deserialize_string_to_i64")] + pub last_op_id: i64, + #[serde(borrow)] + pub updated_buckets: Vec>, + #[serde(borrow)] + pub removed_buckets: Vec>, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_string_to_i64")] + pub write_checkpoint: Option, +} + +#[derive(Deserialize, Debug)] +pub struct CheckpointComplete { + // #[serde(deserialize_with = "deserialize_string_to_i64")] + // pub last_op_id: i64, +} + +#[derive(Deserialize, Debug)] +pub struct CheckpointPartiallyComplete { + // #[serde(deserialize_with = "deserialize_string_to_i64")] + // pub last_op_id: i64, + pub priority: BucketPriority, +} + +#[derive(Deserialize, Debug)] +pub struct BucketChecksum<'a> { + #[serde(borrow)] + pub bucket: SyncLineStr<'a>, + pub checksum: Checksum, + #[serde(default)] + pub priority: Option, + #[serde(default)] + pub count: Option, + // #[serde(default)] + // #[serde(deserialize_with = "deserialize_optional_string_to_i64")] + // pub last_op_id: Option, +} + +#[derive(Deserialize, Debug)] +pub struct DataLine<'a> { + #[serde(borrow)] + pub bucket: SyncLineStr<'a>, + pub data: Vec>, + // #[serde(default)] + // pub has_more: bool, + // #[serde(default, borrow)] + // pub after: Option>, + // #[serde(default, borrow)] + // pub next_after: Option>, +} + +#[derive(Deserialize, Debug)] +pub struct OplogEntry<'a> { + pub checksum: Checksum, + #[serde(deserialize_with = "deserialize_string_to_i64")] + pub op_id: i64, + pub op: OpType, + #[serde(default, borrow)] + pub object_id: Option>, + #[serde(default, borrow)] + pub object_type: Option>, + #[serde(default, borrow)] + pub subkey: Option>, + #[serde(default, borrow)] + pub data: Option>, +} + +#[derive(Debug)] +pub enum OplogData<'a> { + /// A string encoding a well-formed JSON object representing values of the row. + Json { data: Cow<'a, str> }, + // BsonDocument { data: Cow<'a, [u8]> }, +} + +#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub enum OpType { + CLEAR, + MOVE, + PUT, + REMOVE, +} + +#[repr(transparent)] +#[derive(Deserialize, Debug, Clone, Copy)] +pub struct TokenExpiresIn(pub i32); + +impl TokenExpiresIn { + pub fn is_expired(self) -> bool { + self.0 <= 0 + } + + pub fn should_prefetch(self) -> bool { + !self.is_expired() && self.0 <= 30 + } +} + +impl<'a, 'de: 'a> Deserialize<'de> for OplogData<'a> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // For now, we will always get oplog data as a string. 
In the future, there may be the + // option of the sync service sending BSON-encoded data lines too, but that's not relevant + // for now. + return Ok(OplogData::Json { + data: Deserialize::deserialize(deserializer)?, + }); + } +} + +#[cfg(test)] +mod tests { + use core::assert_matches::assert_matches; + + use super::*; + + fn deserialize(source: &str) -> SyncLine { + serde_json::from_str(source).expect("Should have deserialized") + } + + #[test] + fn parse_token_expires_in() { + assert_matches!( + deserialize(r#"{"token_expires_in": 123}"#), + SyncLine::KeepAlive(TokenExpiresIn(123)) + ); + } + + #[test] + fn parse_checkpoint() { + assert_matches!( + deserialize(r#"{"checkpoint": {"last_op_id": "10", "buckets": []}}"#), + SyncLine::Checkpoint(Checkpoint { + last_op_id: 10, + write_checkpoint: None, + buckets: _, + }) + ); + + let SyncLine::Checkpoint(checkpoint) = deserialize( + r#"{"checkpoint": {"last_op_id": "10", "buckets": [{"bucket": "a", "checksum": 10}]}}"#, + ) else { + panic!("Expected checkpoint"); + }; + + assert_eq!(checkpoint.buckets.len(), 1); + let bucket = &checkpoint.buckets[0]; + assert_eq!(bucket.bucket, "a"); + assert_eq!(bucket.checksum, 10u32.into()); + assert_eq!(bucket.priority, None); + + let SyncLine::Checkpoint(checkpoint) = deserialize( + r#"{"checkpoint": {"last_op_id": "10", "buckets": [{"bucket": "a", "priority": 1, "checksum": 10}]}}"#, + ) else { + panic!("Expected checkpoint"); + }; + + assert_eq!(checkpoint.buckets.len(), 1); + let bucket = &checkpoint.buckets[0]; + assert_eq!(bucket.bucket, "a"); + assert_eq!(bucket.checksum, 10u32.into()); + assert_eq!(bucket.priority, Some(BucketPriority { number: 1 })); + + assert_matches!( + deserialize( + r#"{"checkpoint":{"write_checkpoint":null,"last_op_id":"1","buckets":[{"bucket":"a","checksum":0,"priority":3,"count":1}]}}"# + ), + SyncLine::Checkpoint(Checkpoint { + last_op_id: 1, + write_checkpoint: None, + buckets: _, + }) + ); + } + + #[test] + fn parse_checkpoint_diff() { + let SyncLine::CheckpointDiff(diff) = deserialize( + r#"{"checkpoint_diff": {"last_op_id": "10", "buckets": [], "updated_buckets": [], "removed_buckets": [], "write_checkpoint": null}}"#, + ) else { + panic!("Expected checkpoint diff") + }; + + assert_eq!(diff.updated_buckets.len(), 0); + assert_eq!(diff.removed_buckets.len(), 0); + } + + #[test] + fn parse_checkpoint_diff_escape() { + let SyncLine::CheckpointDiff(diff) = deserialize( + r#"{"checkpoint_diff": {"last_op_id": "10", "buckets": [], "updated_buckets": [], "removed_buckets": ["foo\""], "write_checkpoint": null}}"#, + ) else { + panic!("Expected checkpoint diff") + }; + + assert_eq!(diff.removed_buckets[0], "foo\""); + } + + #[test] + fn parse_checkpoint_diff_no_write_checkpoint() { + let SyncLine::CheckpointDiff(_diff) = deserialize( + r#"{"checkpoint_diff":{"last_op_id":"12","updated_buckets":[{"bucket":"a","count":12,"checksum":0,"priority":3}],"removed_buckets":[]}}"#, + ) else { + panic!("Expected checkpoint diff") + }; + } + + #[test] + fn parse_checkpoint_complete() { + assert_matches!( + deserialize(r#"{"checkpoint_complete": {"last_op_id": "10"}}"#), + SyncLine::CheckpointComplete(CheckpointComplete { + // last_op_id: 10 + }) + ); + } + + #[test] + fn parse_checkpoint_partially_complete() { + assert_matches!( + deserialize(r#"{"partial_checkpoint_complete": {"last_op_id": "10", "priority": 1}}"#), + SyncLine::CheckpointPartiallyComplete(CheckpointPartiallyComplete { + //last_op_id: 10, + priority: BucketPriority { number: 1 } + }) + ); + } + + #[test] + fn 
parse_data() { + let SyncLine::Data(data) = deserialize( + r#"{"data": { + "bucket": "bkt", + "data": [{"checksum":10,"op_id":"1","object_id":"test","object_type":"users","op":"PUT","subkey":null,"data":"{\"name\":\"user 0\",\"email\":\"0@example.org\"}"}], + "after": null, + "next_after": null} + }"#, + ) else { + panic!("Expected data line") + }; + + assert_eq!(data.bucket, "bkt"); + + assert_eq!(data.data.len(), 1); + let entry = &data.data[0]; + assert_eq!(entry.checksum, 10u32.into()); + assert_matches!( + &data.data[0], + OplogEntry { + checksum: _, + op_id: 1, + object_id: Some(_), + object_type: Some(_), + op: OpType::PUT, + subkey: None, + data: _, + } + ); + } +} diff --git a/crates/core/src/sync/mod.rs b/crates/core/src/sync/mod.rs new file mode 100644 index 0000000..fb4f02c --- /dev/null +++ b/crates/core/src/sync/mod.rs @@ -0,0 +1,18 @@ +use sqlite_nostd::{self as sqlite, ResultCode}; + +mod bucket_priority; +pub mod checkpoint; +mod checksum; +mod interface; +pub mod line; +pub mod operations; +pub mod storage_adapter; +mod streaming_sync; +mod sync_status; + +pub use bucket_priority::BucketPriority; +pub use checksum::Checksum; + +pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { + interface::register(db) +} diff --git a/crates/core/src/sync/operations.rs b/crates/core/src/sync/operations.rs new file mode 100644 index 0000000..29c23f3 --- /dev/null +++ b/crates/core/src/sync/operations.rs @@ -0,0 +1,210 @@ +use alloc::format; +use alloc::string::String; +use num_traits::Zero; +use sqlite_nostd::Connection; +use sqlite_nostd::{self as sqlite, ResultCode}; + +use crate::{ + error::{PSResult, SQLiteError}, + ext::SafeManagedStmt, +}; + +use super::line::OplogData; +use super::Checksum; +use super::{ + line::{DataLine, OpType}, + storage_adapter::{BucketInfo, StorageAdapter}, +}; + +pub fn insert_bucket_operations( + adapter: &StorageAdapter, + data: &DataLine, +) -> Result<(), SQLiteError> { + let db = adapter.db; + let BucketInfo { + id: bucket_id, + last_applied_op, + } = adapter.lookup_bucket(&*data.bucket)?; + + // This is an optimization for initial sync - we can avoid persisting individual REMOVE + // operations when last_applied_op = 0. + // We do still need to do the "supersede_statement" step for this case, since a REMOVE + // operation can supersede another PUT operation we're syncing at the same time. + let mut is_empty = last_applied_op == 0; + + // Statement to supersede (replace) operations with the same key. 
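As the parse_data fixture above shows, the protocol transmits op_id (and last_op_id) as strings; the line definitions rely on deserialize_string_to_i64 from crate::util for that. Below is a standalone sketch of that deserialize_with pattern (an illustration of the idea, not the util implementation itself):

```rust
use serde::{Deserialize, Deserializer};

// Illustration only: a `deserialize_with` helper that parses a JSON string
// such as "10" into an i64, as used for op_id/last_op_id fields.
fn string_to_i64<'de, D: Deserializer<'de>>(deserializer: D) -> Result<i64, D::Error> {
    let raw = String::deserialize(deserializer)?;
    raw.parse().map_err(serde::de::Error::custom)
}

#[derive(Deserialize)]
struct Entry {
    #[serde(deserialize_with = "string_to_i64")]
    op_id: i64,
}

fn main() {
    let entry: Entry = serde_json::from_str(r#"{"op_id": "1"}"#).unwrap();
    assert_eq!(entry.op_id, 1);

    // Non-numeric strings surface as a deserialization error.
    assert!(serde_json::from_str::<Entry>(r#"{"op_id": "x"}"#).is_err());
}
```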
+ // language=SQLite + let supersede_statement = db.prepare_v2( + "\ +DELETE FROM ps_oplog + WHERE unlikely(ps_oplog.bucket = ?1) + AND ps_oplog.key = ?2 +RETURNING op_id, hash", + )?; + supersede_statement.bind_int64(1, bucket_id)?; + + // language=SQLite + let insert_statement = db.prepare_v2("\ +INSERT INTO ps_oplog(bucket, op_id, key, row_type, row_id, data, hash) VALUES (?, ?, ?, ?, ?, ?, ?)")?; + insert_statement.bind_int64(1, bucket_id)?; + + let updated_row_statement = db.prepare_v2( + "\ +INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) VALUES(?1, ?2)", + )?; + + let mut last_op: Option = None; + let mut add_checksum = Checksum::zero(); + let mut op_checksum = Checksum::zero(); + let mut added_ops: i32 = 0; + + for line in &data.data { + let op_id = line.op_id; + let op = line.op; + let object_type = line.object_type.as_ref(); + let object_id = line.object_id.as_ref(); + let checksum = line.checksum; + let op_data = line.data.as_ref(); + + last_op = Some(op_id); + added_ops += 1; + + if op == OpType::PUT || op == OpType::REMOVE { + let key: String; + if let (Some(object_type), Some(object_id)) = (object_type, object_id) { + let subkey = line.subkey.as_ref().map(|i| &**i).unwrap_or("null"); + key = format!("{}/{}/{}", &object_type, &object_id, subkey); + } else { + key = String::from(""); + } + + supersede_statement.bind_text(2, &key, sqlite::Destructor::STATIC)?; + + let mut superseded = false; + + while supersede_statement.step()? == ResultCode::ROW { + // Superseded (deleted) a previous operation, add the checksum + let supersede_checksum = Checksum::from_i32(supersede_statement.column_int(1)); + add_checksum += supersede_checksum; + op_checksum -= supersede_checksum; + + // Superseded an operation, only skip if the bucket was empty + // Previously this checked "superseded_op <= last_applied_op". + // However, that would not account for a case where a previous + // PUT operation superseded the original PUT operation in this + // same batch, in which case superseded_op is not accurate for this. 
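The key assembled in the loop below is the object_type/object_id/subkey string whose older, double-JSON-encoded form is what powersync_remove_duplicate_key_encoding earlier in this patch repairs. A standalone sketch of the construction with simplified types (the real code reads these fields from the decoded sync line):

```rust
// Illustration only: how the ps_oplog key is assembled. A missing subkey is
// stored as the literal string "null".
fn oplog_key(object_type: &str, object_id: &str, subkey: Option<&str>) -> String {
    format!("{}/{}/{}", object_type, object_id, subkey.unwrap_or("null"))
}

fn main() {
    assert_eq!(oplog_key("users", "u1", None), "users/u1/null");
    assert_eq!(oplog_key("users", "u1", Some("profile")), "users/u1/profile");

    // Subkeys may themselves contain slashes; they are not escaped here, which
    // is what the duplicate-encoding fix has to cope with when it scans
    // backwards for the opening quote.
    assert_eq!(
        oplog_key("users", "u1", Some("nested/subkey")),
        "users/u1/nested/subkey"
    );
}
```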
+ if !is_empty { + superseded = true; + } + } + supersede_statement.reset()?; + + if op == OpType::REMOVE { + let should_skip_remove = !superseded; + + add_checksum += checksum; + + if !should_skip_remove { + if let (Some(object_type), Some(object_id)) = (object_type, object_id) { + updated_row_statement.bind_text( + 1, + object_type, + sqlite::Destructor::STATIC, + )?; + updated_row_statement.bind_text( + 2, + object_id, + sqlite::Destructor::STATIC, + )?; + updated_row_statement.exec()?; + } + } + + continue; + } + + insert_statement.bind_int64(2, op_id)?; + if key != "" { + insert_statement.bind_text(3, &key, sqlite::Destructor::STATIC)?; + } else { + insert_statement.bind_null(3)?; + } + + if let (Some(object_type), Some(object_id)) = (object_type, object_id) { + insert_statement.bind_text(4, object_type, sqlite::Destructor::STATIC)?; + insert_statement.bind_text(5, object_id, sqlite::Destructor::STATIC)?; + } else { + insert_statement.bind_null(4)?; + insert_statement.bind_null(5)?; + } + if let Some(data) = op_data { + let OplogData::Json { data } = data; + + insert_statement.bind_text(6, data, sqlite::Destructor::STATIC)?; + } else { + insert_statement.bind_null(6)?; + } + + insert_statement.bind_int(7, checksum.bitcast_i32())?; + insert_statement.exec()?; + + op_checksum += checksum; + } else if op == OpType::MOVE { + add_checksum += checksum; + } else if op == OpType::CLEAR { + // Any remaining PUT operations should get an implicit REMOVE + // language=SQLite + let clear_statement1 = db + .prepare_v2( + "INSERT OR IGNORE INTO ps_updated_rows(row_type, row_id) +SELECT row_type, row_id +FROM ps_oplog +WHERE bucket = ?1", + ) + .into_db_result(db)?; + clear_statement1.bind_int64(1, bucket_id)?; + clear_statement1.exec()?; + + let clear_statement2 = db + .prepare_v2("DELETE FROM ps_oplog WHERE bucket = ?1") + .into_db_result(db)?; + clear_statement2.bind_int64(1, bucket_id)?; + clear_statement2.exec()?; + + // And we need to re-apply all of those. + // We also replace the checksum with the checksum of the CLEAR op. 
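One detail worth noting for the bucket counter update that follows: SQLite evaluates (add_checksum + ?) & 0xffffffff in 64-bit signed arithmetic while the parameter is bound as a bit-cast i32, and that combination matches 32-bit wrapping addition. A standalone check of the equivalence (illustration only):

```rust
// Illustration only: the SQL expression `(add_checksum + ?) & 0xffffffff`
// (64-bit signed arithmetic, parameter bound via bitcast_i32) produces the
// same result as wrapping u32 addition.
fn sql_style(stored: u32, delta: u32) -> i64 {
    let bound_param = delta as i32; // what bind_int(..., checksum.bitcast_i32()) passes
    (i64::from(stored) + i64::from(bound_param)) & 0xffff_ffff
}

fn main() {
    let cases = [
        (0u32, 1u32),
        (u32::MAX, 1),           // wraps around to 0
        (5, u32::MAX),           // adding a "negative" bit-cast value
        (0x8FFF_FFFF, 0x8000_0000),
    ];

    for (stored, delta) in cases {
        assert_eq!(sql_style(stored, delta), i64::from(stored.wrapping_add(delta)));
    }
}
```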
+ // language=SQLite + let clear_statement2 = db.prepare_v2( + "UPDATE ps_buckets SET last_applied_op = 0, add_checksum = ?1, op_checksum = 0 WHERE id = ?2", + )?; + clear_statement2.bind_int64(2, bucket_id)?; + clear_statement2.bind_int(1, checksum.bitcast_i32())?; + clear_statement2.exec()?; + + add_checksum = Checksum::zero(); + is_empty = true; + op_checksum = Checksum::zero(); + } + } + + if let Some(last_op) = &last_op { + // language=SQLite + let statement = db.prepare_v2( + "UPDATE ps_buckets + SET last_op = ?2, + add_checksum = (add_checksum + ?3) & 0xffffffff, + op_checksum = (op_checksum + ?4) & 0xffffffff, + count_since_last = count_since_last + ?5 + WHERE id = ?1", + )?; + statement.bind_int64(1, bucket_id)?; + statement.bind_int64(2, *last_op)?; + statement.bind_int(3, add_checksum.bitcast_i32())?; + statement.bind_int(4, op_checksum.bitcast_i32())?; + statement.bind_int(5, added_ops)?; + + statement.exec()?; + } + + Ok(()) +} diff --git a/crates/core/src/sync/storage_adapter.rs b/crates/core/src/sync/storage_adapter.rs new file mode 100644 index 0000000..ed71b79 --- /dev/null +++ b/crates/core/src/sync/storage_adapter.rs @@ -0,0 +1,314 @@ +use core::{assert_matches::debug_assert_matches, fmt::Display}; + +use alloc::{string::ToString, vec::Vec}; +use serde::Serialize; +use sqlite_nostd::{self as sqlite, Connection, ManagedStmt, ResultCode}; +use streaming_iterator::StreamingIterator; + +use crate::{ + error::SQLiteError, + ext::SafeManagedStmt, + operations::delete_bucket, + sync::checkpoint::{validate_checkpoint, ChecksumMismatch}, + sync_local::{PartialSyncOperation, SyncOperation}, +}; + +use super::{ + bucket_priority::BucketPriority, interface::BucketRequest, streaming_sync::OwnedCheckpoint, + sync_status::Timestamp, +}; + +/// An adapter for storing sync state. +/// +/// This is used to encapsulate some SQL queries used for the sync implementation, making the code +/// in `streaming_sync.rs` easier to read. It also allows caching some prepared statements that are +/// used frequently as an optimization, but we're not taking advantage of that yet. +pub struct StorageAdapter { + pub db: *mut sqlite::sqlite3, + progress_stmt: ManagedStmt, + time_stmt: ManagedStmt, +} + +impl StorageAdapter { + pub fn new(db: *mut sqlite::sqlite3) -> Result { + // language=SQLite + let progress = + db.prepare_v2("SELECT name, count_at_last, count_since_last FROM ps_buckets")?; + + // language=SQLite + let time = db.prepare_v2("SELECT unixepoch()")?; + + Ok(Self { + db, + progress_stmt: progress, + time_stmt: time, + }) + } + + pub fn collect_bucket_requests(&self) -> Result, SQLiteError> { + // language=SQLite + let statement = self.db.prepare_v2( + "SELECT name, last_op FROM ps_buckets WHERE pending_delete = 0 AND name != '$local'", + )?; + + let mut requests = Vec::::new(); + + while statement.step()? 
== ResultCode::ROW { + let bucket_name = statement.column_text(0)?.to_string(); + let last_op = statement.column_int64(1); + + requests.push(BucketRequest { + name: bucket_name.clone(), + after: last_op.to_string(), + }); + } + + Ok(requests) + } + + pub fn delete_buckets<'a>( + &self, + buckets: impl IntoIterator, + ) -> Result<(), SQLiteError> { + for bucket in buckets { + // TODO: This is a neat opportunity to create the statements here and cache them + delete_bucket(self.db, bucket)?; + } + + Ok(()) + } + + pub fn local_progress( + &self, + ) -> Result< + impl StreamingIterator>, + ResultCode, + > { + self.progress_stmt.reset()?; + + fn step(stmt: &ManagedStmt) -> Result, ResultCode> { + if stmt.step()? == ResultCode::ROW { + let bucket = stmt.column_text(0)?; + let count_at_last = stmt.column_int64(1); + let count_since_last = stmt.column_int64(2); + + return Ok(Some(PersistedBucketProgress { + bucket, + count_at_last, + count_since_last, + })); + } + + Ok(None) + } + + Ok(streaming_iterator::from_fn(|| { + match step(&self.progress_stmt) { + Err(e) => Some(Err(e)), + Ok(Some(other)) => Some(Ok(other)), + Ok(None) => None, + } + })) + } + + pub fn reset_progress(&self) -> Result<(), ResultCode> { + self.db + .exec_safe("UPDATE ps_buckets SET count_since_last = 0, count_at_last = 0;")?; + Ok(()) + } + + pub fn lookup_bucket(&self, bucket: &str) -> Result { + // We do an ON CONFLICT UPDATE simply so that the RETURNING bit works for existing rows. + // We can consider splitting this into separate SELECT and INSERT statements. + // language=SQLite + let bucket_statement = self.db.prepare_v2( + "INSERT INTO ps_buckets(name) + VALUES(?) + ON CONFLICT DO UPDATE + SET last_applied_op = last_applied_op + RETURNING id, last_applied_op", + )?; + bucket_statement.bind_text(1, bucket, sqlite::Destructor::STATIC)?; + let res = bucket_statement.step()?; + debug_assert_matches!(res, ResultCode::ROW); + + let bucket_id = bucket_statement.column_int64(0); + let last_applied_op = bucket_statement.column_int64(1); + + return Ok(BucketInfo { + id: bucket_id, + last_applied_op, + }); + } + + pub fn sync_local( + &self, + checkpoint: &OwnedCheckpoint, + priority: Option, + ) -> Result { + let mismatched_checksums = + validate_checkpoint(checkpoint.buckets.values(), priority, self.db)?; + + if !mismatched_checksums.is_empty() { + self.delete_buckets(mismatched_checksums.iter().map(|i| i.bucket_name.as_str()))?; + + return Ok(SyncLocalResult::ChecksumFailure(CheckpointResult { + failed_buckets: mismatched_checksums, + })); + } + + let update_bucket = self + .db + .prepare_v2("UPDATE ps_buckets SET last_op = ? 
WHERE name = ?")?; + + for bucket in checkpoint.buckets.values() { + if bucket.is_in_priority(priority) { + update_bucket.bind_int64(1, checkpoint.last_op_id)?; + update_bucket.bind_text(2, &bucket.bucket, sqlite::Destructor::STATIC)?; + update_bucket.exec()?; + } + } + + if let (None, Some(write_checkpoint)) = (&priority, &checkpoint.write_checkpoint) { + update_bucket.bind_int64(1, *write_checkpoint)?; + update_bucket.bind_text(2, "$local", sqlite::Destructor::STATIC)?; + update_bucket.exec()?; + } + + #[derive(Serialize)] + struct PartialArgs<'a> { + priority: BucketPriority, + buckets: Vec<&'a str>, + } + + let sync_result = match priority { + None => SyncOperation::new(self.db, None).apply(), + Some(priority) => { + let args = PartialArgs { + priority, + buckets: checkpoint + .buckets + .values() + .filter_map(|item| { + if item.is_in_priority(Some(priority)) { + Some(item.bucket.as_str()) + } else { + None + } + }) + .collect(), + }; + + // TODO: Avoid this serialization, it's currently used to bind JSON SQL parameters. + let serialized_args = serde_json::to_string(&args)?; + SyncOperation::new( + self.db, + Some(PartialSyncOperation { + priority, + args: &serialized_args, + }), + ) + .apply() + } + }?; + + if sync_result == 1 { + if priority.is_none() { + // Reset progress counters. We only do this for a complete sync, as we want a + // download progress to always cover a complete checkpoint instead of resetting for + // partial completions. + let update = self.db.prepare_v2( + "UPDATE ps_buckets SET count_since_last = 0, count_at_last = ? WHERE name = ?", + )?; + + for bucket in checkpoint.buckets.values() { + if let Some(count) = bucket.count { + update.bind_int64(1, count)?; + update.bind_text(2, bucket.bucket.as_str(), sqlite::Destructor::STATIC)?; + + update.exec()?; + update.reset()?; + } + } + } + + Ok(SyncLocalResult::ChangesApplied) + } else { + Ok(SyncLocalResult::PendingLocalChanges) + } + } + + pub fn now(&self) -> Result { + self.time_stmt.reset()?; + self.time_stmt.step()?; + + Ok(Timestamp(self.time_stmt.column_int64(0))) + } +} + +pub struct BucketInfo { + pub id: i64, + pub last_applied_op: i64, +} + +pub struct CheckpointResult { + failed_buckets: Vec, +} + +impl CheckpointResult { + pub fn is_valid(&self) -> bool { + self.failed_buckets.is_empty() + } +} + +impl Display for CheckpointResult { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + if self.is_valid() { + write!(f, "Valid checkpoint result") + } else { + write!(f, "Checksums didn't match, failed for: ")?; + for (i, item) in self.failed_buckets.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + + item.fmt(f)?; + } + + Ok(()) + } + } +} + +impl Display for ChecksumMismatch { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let actual = self.actual_add_checksum + self.actual_op_checksum; + write!( + f, + "{} (expected {}, got {} = {} (op) + {} (add))", + self.bucket_name, + self.expected_checksum, + actual, + self.actual_op_checksum, + self.actual_add_checksum + ) + } +} + +pub enum SyncLocalResult { + /// Changes could not be applied due to a checksum mismatch. + ChecksumFailure(CheckpointResult), + /// Changes could not be applied because they would break consistency - we need to wait for + /// pending local CRUD data to be uploaded and acknowledged in a write checkpoint. + PendingLocalChanges, + /// The checkpoint has been applied and changes have been published. 
+ ChangesApplied, +} + +/// Information about the amount of operations a bucket had at the last checkpoint and how many +/// operations have been inserted in the meantime. +pub struct PersistedBucketProgress<'a> { + pub bucket: &'a str, + pub count_at_last: i64, + pub count_since_last: i64, +} diff --git a/crates/core/src/sync/streaming_sync.rs b/crates/core/src/sync/streaming_sync.rs new file mode 100644 index 0000000..d5f2f51 --- /dev/null +++ b/crates/core/src/sync/streaming_sync.rs @@ -0,0 +1,562 @@ +use core::{ + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll, Waker}, +}; + +use alloc::{ + boxed::Box, + collections::{btree_map::BTreeMap, btree_set::BTreeSet}, + format, + string::{String, ToString}, + vec::Vec, +}; + use futures_lite::FutureExt; + +use crate::{bson, error::SQLiteError, kv::client_id, sync::checkpoint::OwnedBucketChecksum}; +use sqlite_nostd::{self as sqlite, ResultCode}; + +use super::{ + interface::{Instruction, LogSeverity, StreamingSyncRequest, SyncControlRequest, SyncEvent}, + line::{Checkpoint, CheckpointDiff, SyncLine}, + operations::insert_bucket_operations, + storage_adapter::{StorageAdapter, SyncLocalResult}, + sync_status::{SyncDownloadProgress, SyncProgressFromCheckpoint, SyncStatusContainer}, +}; + +/// The sync client implementation, responsible for parsing lines received by the sync service and +/// persisting them to the database. +/// +/// The client consumes no resources and prepares no statements until a sync iteration is +/// initialized. +pub struct SyncClient { + db: *mut sqlite::sqlite3, + /// The current [ClientState] (essentially an optional [StreamingSyncIteration]). + state: ClientState, +} + +impl SyncClient { + pub fn new(db: *mut sqlite::sqlite3) -> Self { + Self { + db, + state: ClientState::Idle, + } + } + + pub fn push_event<'a>( + &mut self, + event: SyncControlRequest<'a>, + ) -> Result<Vec<Instruction>, SQLiteError> { + match event { + SyncControlRequest::StartSyncStream(options) => { + self.state.tear_down()?; + + let mut handle = SyncIterationHandle::new(self.db, options.parameters)?; + let instructions = handle.initialize()?; + self.state = ClientState::IterationActive(handle); + + Ok(instructions) + } + SyncControlRequest::SyncEvent(sync_event) => { + let mut active = ActiveEvent::new(sync_event); + + let ClientState::IterationActive(handle) = &mut self.state else { + return Err(SQLiteError( + ResultCode::MISUSE, + Some("No iteration is active".to_string()), + )); + }; + + match handle.run(&mut active) { + Err(e) => { + self.state = ClientState::Idle; + return Err(e); + } + Ok(done) => { + if done { + self.state = ClientState::Idle; + } + } + }; + + Ok(active.instructions) + } + SyncControlRequest::StopSyncStream => self.state.tear_down(), + } + } +} + +enum ClientState { + /// No sync iteration is currently active. + Idle, + /// A sync iteration has begun on the database. + IterationActive(SyncIterationHandle), +} + +impl ClientState { + fn tear_down(&mut self) -> Result<Vec<Instruction>, SQLiteError> { + let mut event = ActiveEvent::new(SyncEvent::TearDown); + + if let ClientState::IterationActive(old) = self { + old.run(&mut event)?; + }; + + *self = ClientState::Idle; + Ok(event.instructions) + } +} + +/// A handle that allows progressing a [StreamingSyncIteration]. +/// +/// The sync iteration itself is implemented as an `async` function, as this allows us to treat it +/// as a coroutine that preserves internal state between multiple `powersync_control` invocations.
+/// At each invocation, the future is polled once (and gets access to context that allows it to +/// render [Instruction]s to return from the function). +struct SyncIterationHandle { + future: Pin>>>, +} + +impl SyncIterationHandle { + /// Creates a new sync iteration in a pending state by preparing statements for + /// [StorageAdapter] and setting up the initial downloading state for [StorageAdapter] . + fn new( + db: *mut sqlite::sqlite3, + parameters: Option>, + ) -> Result { + let runner = StreamingSyncIteration { + db, + parameters, + adapter: StorageAdapter::new(db)?, + status: SyncStatusContainer::new(), + }; + let future = runner.run().boxed_local(); + + Ok(Self { future }) + } + + /// Forwards a [SyncEvent::Initialize] to the current sync iteration, returning the initial + /// instructions generated. + fn initialize(&mut self) -> Result, SQLiteError> { + let mut event = ActiveEvent::new(SyncEvent::Initialize); + let result = self.run(&mut event)?; + assert!(!result, "Stream client aborted initialization"); + + Ok(event.instructions) + } + + fn run(&mut self, active: &mut ActiveEvent) -> Result { + // Using a noop waker because the only event thing StreamingSyncIteration::run polls on is + // the next incoming sync event. + let waker = unsafe { + Waker::new( + active as *const ActiveEvent as *const (), + Waker::noop().vtable(), + ) + }; + let mut context = Context::from_waker(&waker); + + Ok( + if let Poll::Ready(result) = self.future.poll(&mut context) { + result?; + + active.instructions.push(Instruction::CloseSyncStream {}); + true + } else { + false + }, + ) + } +} + +/// A [SyncEvent] currently being handled by a [StreamingSyncIteration]. +struct ActiveEvent<'a> { + handled: bool, + /// The event to handle + event: SyncEvent<'a>, + /// Instructions to forward to the client when the `powersync_control` invocation completes. + instructions: Vec, +} + +impl<'a> ActiveEvent<'a> { + pub fn new(event: SyncEvent<'a>) -> Self { + Self { + handled: false, + event, + instructions: Vec::new(), + } + } +} + +struct StreamingSyncIteration { + db: *mut sqlite::sqlite3, + adapter: StorageAdapter, + parameters: Option>, + status: SyncStatusContainer, +} + +impl StreamingSyncIteration { + fn receive_event<'a>() -> impl Future> { + struct Wait<'a> { + a: PhantomData<&'a StreamingSyncIteration>, + } + + impl<'a> Future for Wait<'a> { + type Output = &'a mut ActiveEvent<'a>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let context = cx.waker().data().cast_mut() as *mut ActiveEvent; + let context = unsafe { &mut *context }; + + if context.handled { + Poll::Pending + } else { + context.handled = true; + Poll::Ready(context) + } + } + } + + Wait { a: PhantomData } + } + + async fn run(mut self) -> Result<(), SQLiteError> { + let mut target = SyncTarget::BeforeCheckpoint(self.prepare_request().await?); + + // A checkpoint that has been fully received and validated, but couldn't be applied due to + // pending local data. We will retry applying this checkpoint when the client SDK informs us + // that it has finished uploading changes. + let mut validated_but_not_applied = None::; + + loop { + let event = Self::receive_event().await; + + let line: SyncLine = match event.event { + SyncEvent::Initialize { .. 
} => { + panic!("Initialize should only be emitted once") + } + SyncEvent::TearDown => { + self.status + .update(|s| s.disconnect(), &mut event.instructions); + break; + } + SyncEvent::TextLine { data } => serde_json::from_str(data)?, + SyncEvent::BinaryLine { data } => bson::from_bytes(data)?, + SyncEvent::UploadFinished => { + if let Some(checkpoint) = validated_but_not_applied.take() { + let result = self.adapter.sync_local(&checkpoint, None)?; + + match result { + SyncLocalResult::ChangesApplied => { + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::DEBUG, + line: "Applied pending checkpoint after completed upload" + .into(), + }); + + self.handle_checkpoint_applied(event)?; + } + _ => { + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::WARNING, + line: "Could not apply pending checkpoint even after completed upload" + .into(), + }); + } + } + } + + continue; + } + SyncEvent::DidRefreshToken => { + // Break so that the client SDK starts another iteration. + break; + } + }; + + self.status.update_only(|s| s.mark_connected()); + + match line { + SyncLine::Checkpoint(checkpoint) => { + validated_but_not_applied = None; + let to_delete = target.track_checkpoint(&checkpoint); + + self.adapter + .delete_buckets(to_delete.iter().map(|b| b.as_str()))?; + let progress = self.load_progress(target.target_checkpoint().unwrap())?; + self.status.update( + |s| s.start_tracking_checkpoint(progress), + &mut event.instructions, + ); + } + SyncLine::CheckpointDiff(diff) => { + let Some(target) = target.target_checkpoint_mut() else { + return Err(SQLiteError( + ResultCode::ABORT, + Some( + "Received checkpoint_diff without previous checkpoint".to_string(), + ), + )); + }; + + target.apply_diff(&diff); + validated_but_not_applied = None; + self.adapter + .delete_buckets(diff.removed_buckets.iter().map(|i| &**i))?; + + let progress = self.load_progress(target)?; + self.status.update( + |s| s.start_tracking_checkpoint(progress), + &mut event.instructions, + ); + } + SyncLine::CheckpointComplete(_) => { + let Some(target) = target.target_checkpoint_mut() else { + return Err(SQLiteError( + ResultCode::ABORT, + Some( + "Received checkpoint complete without previous checkpoint" + .to_string(), + ), + )); + }; + let result = self.adapter.sync_local(target, None)?; + + match result { + SyncLocalResult::ChecksumFailure(checkpoint_result) => { + // This means checksums failed. Start again with a new checkpoint. + // TODO: better back-off + // await new Promise((resolve) => setTimeout(resolve, 50)); + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::WARNING, + line: format!("Could not apply checkpoint, {checkpoint_result}") + .into(), + }); + break; + } + SyncLocalResult::PendingLocalChanges => { + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::INFO, + line: "Could not apply checkpoint due to local data.
Will retry at completed upload or next checkpoint.".into(), + }); + + validated_but_not_applied = Some(target.clone()); + } + SyncLocalResult::ChangesApplied => { + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::DEBUG, + line: "Validated and applied checkpoint".into(), + }); + event.instructions.push(Instruction::FlushFileSystem {}); + self.handle_checkpoint_applied(event)?; + } + } + } + SyncLine::CheckpointPartiallyComplete(complete) => { + let priority = complete.priority; + let Some(target) = target.target_checkpoint_mut() else { + return Err(SQLiteError( + ResultCode::ABORT, + Some( + "Received checkpoint complete without previous checkpoint" + .to_string(), + ), + )); + }; + let result = self.adapter.sync_local(target, Some(priority))?; + + match result { + SyncLocalResult::ChecksumFailure(checkpoint_result) => { + // This means checksums failed. Start again with a new checkpoint. + // TODO: better back-off + // await new Promise((resolve) => setTimeout(resolve, 50)); + event.instructions.push(Instruction::LogLine { + severity: LogSeverity::WARNING, + line: format!( + "Could not apply partial checkpoint, {checkpoint_result}" + ) + .into(), + }); + break; + } + SyncLocalResult::PendingLocalChanges => { + // If we have pending uploads, we can't complete new checkpoints outside + // of priority 0. We'll resolve this for a complete checkpoint later. + } + SyncLocalResult::ChangesApplied => { + let now = self.adapter.now()?; + event.instructions.push(Instruction::FlushFileSystem {}); + self.status.update( + |status| { + status.partial_checkpoint_complete(priority, now); + }, + &mut event.instructions, + ); + } + } + } + SyncLine::Data(data_line) => { + self.status + .update(|s| s.track_line(&data_line), &mut event.instructions); + insert_bucket_operations(&self.adapter, &data_line)?; + } + SyncLine::KeepAlive(token) => { + if token.is_expired() { + // Token expired already - stop the connection immediately. + event + .instructions + .push(Instruction::FetchCredentials { did_expire: true }); + break; + } else if token.should_prefetch() { + event + .instructions + .push(Instruction::FetchCredentials { did_expire: false }); + } + } + } + + self.status.emit_changes(&mut event.instructions); + } + + Ok(()) + } + + fn load_progress( + &self, + checkpoint: &OwnedCheckpoint, + ) -> Result { + let local_progress = self.adapter.local_progress()?; + let SyncProgressFromCheckpoint { + progress, + needs_counter_reset, + } = SyncDownloadProgress::for_checkpoint(checkpoint, local_progress)?; + + if needs_counter_reset { + self.adapter.reset_progress()?; + } + + Ok(progress) + } + + /// Prepares a sync iteration by handling the initial [SyncEvent::Initialize]. + /// + /// This prepares a [StreamingSyncRequest] by fetching local sync state and the requested bucket + /// parameters. 
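// --- Editor's aside (illustrative sketch, not part of this patch) ---
// The request assembled in `prepare_request` below is surfaced to the host SDK as an
// `EstablishSyncStream` instruction returned from `powersync_control`; the Dart benchmark
// later in this diff drives the same entry point with `select powersync_control('start', null)`.
// A host written against rusqlite (an assumption for illustration; the real SDKs are
// Dart/JS/Kotlin/Swift) could start an iteration roughly like this:
//
//     fn start_sync_iteration(db: &rusqlite::Connection) -> rusqlite::Result<String> {
//         // Returns a JSON array of instructions, including EstablishSyncStream with the
//         // bucket requests and parameters collected below.
//         db.query_row("SELECT powersync_control('start', NULL)", [], |row| row.get(0))
//     }
// --- end aside ---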
+ async fn prepare_request(&mut self) -> Result, SQLiteError> { + let event = Self::receive_event().await; + let SyncEvent::Initialize = event.event else { + return Err(SQLiteError::from(ResultCode::MISUSE)); + }; + + self.status + .update(|s| s.start_connecting(), &mut event.instructions); + + let requests = self.adapter.collect_bucket_requests()?; + let local_bucket_names: Vec = requests.iter().map(|s| s.name.clone()).collect(); + let request = StreamingSyncRequest { + buckets: requests, + include_checksum: true, + raw_data: true, + binary_data: true, + client_id: client_id(self.db)?, + parameters: self.parameters.take(), + }; + + event + .instructions + .push(Instruction::EstablishSyncStream { request }); + Ok(local_bucket_names) + } + + fn handle_checkpoint_applied(&mut self, event: &mut ActiveEvent) -> Result<(), ResultCode> { + event.instructions.push(Instruction::DidCompleteSync {}); + + let now = self.adapter.now()?; + self.status.update( + |status| status.applied_checkpoint(now), + &mut event.instructions, + ); + + Ok(()) + } +} + +#[derive(Debug)] +enum SyncTarget { + /// We've received a checkpoint line towards the given checkpoint. The tracked checkpoint is + /// updated for subsequent checkpoint or checkpoint_diff lines. + Tracking(OwnedCheckpoint), + /// We have not received a checkpoint message yet. We still keep a list of local buckets around + /// so that we know which ones to delete depending on the first checkpoint message. + BeforeCheckpoint(Vec), +} + +impl SyncTarget { + fn target_checkpoint(&self) -> Option<&OwnedCheckpoint> { + match self { + Self::Tracking(cp) => Some(cp), + _ => None, + } + } + + fn target_checkpoint_mut(&mut self) -> Option<&mut OwnedCheckpoint> { + match self { + Self::Tracking(cp) => Some(cp), + _ => None, + } + } + + /// Starts tracking the received `Checkpoint`. + /// + /// This updates the internal state and returns a set of buckets to delete because they've been + /// tracked locally but not in the new checkpoint. 
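// --- Editor's aside (illustrative sketch, not part of this patch) ---
// `track_checkpoint` below is effectively a set difference over bucket names: anything
// tracked locally (or listed before the first checkpoint) that is absent from the new
// checkpoint is returned for deletion. For example, with local buckets {"a", "b"} and a
// checkpoint listing {"b", "c"}, "a" is deleted and {"b", "c"} becomes the tracked set.
// A standalone model of that rule (using std for brevity; the crate itself is no_std + alloc):
//
//     use std::collections::BTreeSet;
//
//     fn buckets_to_delete(
//         local: &BTreeSet<String>,
//         in_checkpoint: &BTreeSet<String>,
//     ) -> BTreeSet<String> {
//         // Keep only the locally-known buckets that the new checkpoint no longer mentions.
//         local.difference(in_checkpoint).cloned().collect()
//     }
// --- end aside ---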
+ fn track_checkpoint<'a>(&mut self, checkpoint: &Checkpoint<'a>) -> BTreeSet { + let mut to_delete: BTreeSet = match &self { + SyncTarget::Tracking(checkpoint) => checkpoint.buckets.keys().cloned().collect(), + SyncTarget::BeforeCheckpoint(buckets) => buckets.iter().cloned().collect(), + }; + + let mut buckets = BTreeMap::::new(); + for bucket in &checkpoint.buckets { + buckets.insert(bucket.bucket.to_string(), OwnedBucketChecksum::from(bucket)); + to_delete.remove(&*bucket.bucket); + } + + *self = SyncTarget::Tracking(OwnedCheckpoint::from_checkpoint(checkpoint, buckets)); + to_delete + } +} + +#[derive(Debug, Clone)] +pub struct OwnedCheckpoint { + pub last_op_id: i64, + pub write_checkpoint: Option, + pub buckets: BTreeMap, +} + +impl OwnedCheckpoint { + fn from_checkpoint<'a>( + checkpoint: &Checkpoint<'a>, + buckets: BTreeMap, + ) -> Self { + Self { + last_op_id: checkpoint.last_op_id, + write_checkpoint: checkpoint.write_checkpoint, + buckets: buckets, + } + } + + fn apply_diff<'a>(&mut self, diff: &CheckpointDiff<'a>) { + for removed in &diff.removed_buckets { + self.buckets.remove(&**removed); + } + + for updated in &diff.updated_buckets { + let owned = OwnedBucketChecksum::from(updated); + self.buckets.insert(owned.bucket.clone(), owned); + } + + self.last_op_id = diff.last_op_id; + self.write_checkpoint = diff.write_checkpoint; + } +} diff --git a/crates/core/src/sync/sync_status.rs b/crates/core/src/sync/sync_status.rs new file mode 100644 index 0000000..e6744fa --- /dev/null +++ b/crates/core/src/sync/sync_status.rs @@ -0,0 +1,246 @@ +use alloc::{collections::btree_map::BTreeMap, rc::Rc, string::String, vec::Vec}; +use core::{cell::RefCell, hash::BuildHasher}; +use rustc_hash::FxBuildHasher; +use serde::Serialize; +use sqlite_nostd::ResultCode; +use streaming_iterator::StreamingIterator; + +use super::{ + bucket_priority::BucketPriority, interface::Instruction, line::DataLine, + storage_adapter::PersistedBucketProgress, streaming_sync::OwnedCheckpoint, +}; + +/// Information about a progressing download. +#[derive(Serialize, Hash)] +pub struct DownloadSyncStatus { + /// Whether the socket to the sync service is currently open and connected. + /// + /// This starts being true once we receive the first line, and is set to false as the iteration + /// ends. + pub connected: bool, + /// Whether we've requested the client SDK to connect to the socket while not receiving sync + /// lines yet. + pub connecting: bool, + /// Provides stats over which bucket priorities have already been synced (or when they've last + /// been changed). + /// + /// Always sorted by descending [BucketPriority] in [SyncPriorityStatus] (or, in other words, + /// increasing priority numbers). + pub priority_status: Vec, + /// When a download is active (that is, a `checkpoint` or `checkpoint_diff` line has been + /// received), information about how far the download has progressed. + pub downloading: Option, +} + +impl DownloadSyncStatus { + fn debug_assert_priority_status_is_sorted(&self) { + debug_assert!(self + .priority_status + .is_sorted_by(|a, b| a.priority >= b.priority)) + } + + pub fn disconnect(&mut self) { + self.connected = false; + self.connecting = false; + self.downloading = None; + } + + pub fn start_connecting(&mut self) { + self.connected = false; + self.downloading = None; + self.connecting = true; + } + + pub fn mark_connected(&mut self) { + self.connecting = false; + self.connected = true; + } + + /// Transitions state after receiving a checkpoint line. 
+ /// + /// This sets the [downloading] state to include [progress]. + pub fn start_tracking_checkpoint<'a>(&mut self, progress: SyncDownloadProgress) { + self.mark_connected(); + + self.downloading = Some(progress); + } + + /// Increments [SyncDownloadProgress] progress for the given [DataLine]. + pub fn track_line(&mut self, line: &DataLine) { + if let Some(ref mut downloading) = self.downloading { + downloading.increment_download_count(line); + } + } + + pub fn partial_checkpoint_complete(&mut self, priority: BucketPriority, now: Timestamp) { + self.debug_assert_priority_status_is_sorted(); + // We can delete entries with a higher priority because this partial sync includes them. + self.priority_status.retain(|i| i.priority < priority); + self.priority_status.insert( + 0, + SyncPriorityStatus { + priority: priority, + last_synced_at: Some(now), + has_synced: Some(true), + }, + ); + self.debug_assert_priority_status_is_sorted(); + } + + pub fn applied_checkpoint(&mut self, now: Timestamp) { + self.downloading = None; + self.priority_status.clear(); + + self.priority_status.push(SyncPriorityStatus { + priority: BucketPriority::SENTINEL, + last_synced_at: Some(now), + has_synced: Some(true), + }); + } +} + +impl Default for DownloadSyncStatus { + fn default() -> Self { + Self { + connected: false, + connecting: false, + downloading: None, + priority_status: Vec::new(), + } + } +} + +pub struct SyncStatusContainer { + status: Rc>, + last_published_hash: u64, +} + +impl SyncStatusContainer { + pub fn new() -> Self { + Self { + status: Rc::new(RefCell::new(Default::default())), + last_published_hash: 0, + } + } + + /// Invokes a function to update the sync status, then emits an [Instruction::UpdateSyncStatus] + /// if the function did indeed change the status. + pub fn update ()>( + &mut self, + apply: F, + instructions: &mut Vec, + ) { + self.update_only(apply); + self.emit_changes(instructions); + } + + /// Invokes a function to update the sync status without emitting a status event. + pub fn update_only ()>(&self, apply: F) { + let mut status = self.status.borrow_mut(); + apply(&mut *status); + } + + /// If the status has been changed since the last time an [Instruction::UpdateSyncStatus] event + /// was emitted, emit such an event now. + pub fn emit_changes(&mut self, instructions: &mut Vec) { + let status = self.status.borrow(); + let hash = FxBuildHasher.hash_one(&*status); + if hash != self.last_published_hash { + self.last_published_hash = hash; + instructions.push(Instruction::UpdateSyncStatus { + status: self.status.clone(), + }); + } + } +} + +#[repr(transparent)] +#[derive(Serialize, Hash, Clone, Copy)] +pub struct Timestamp(pub i64); + +#[derive(Serialize, Hash)] +pub struct SyncPriorityStatus { + priority: BucketPriority, + last_synced_at: Option, + has_synced: Option, +} + +/// Per-bucket download progress information. 
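// --- Editor's aside (illustrative sketch, not part of this patch) ---
// In `BucketProgress` below, `target_count` is the operation count the checkpoint reports
// for the bucket, while `at_last` and `since_last` are the locally persisted counters (see
// `for_checkpoint` further down, which resets everything when their sum would exceed the
// target and overshoot 100%). A client SDK might derive a display fraction from these
// counters roughly as follows; the exact formula used by the SDKs is an assumption here:
//
//     fn download_fraction(at_last: i64, since_last: i64, target_count: i64) -> f64 {
//         if target_count <= 0 {
//             return 1.0; // Nothing to download for this bucket.
//         }
//         ((at_last + since_last) as f64 / target_count as f64).clamp(0.0, 1.0)
//     }
// --- end aside ---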
+#[derive(Serialize, Hash)] +pub struct BucketProgress { + pub priority: BucketPriority, + pub at_last: i64, + pub since_last: i64, + pub target_count: i64, +} + +#[derive(Serialize, Hash)] +pub struct SyncDownloadProgress { + buckets: BTreeMap<String, BucketProgress>, +} + +pub struct SyncProgressFromCheckpoint { + pub progress: SyncDownloadProgress, + pub needs_counter_reset: bool, +} + +impl SyncDownloadProgress { + pub fn for_checkpoint<'a>( + checkpoint: &OwnedCheckpoint, + mut local_progress: impl StreamingIterator< + Item = Result<PersistedBucketProgress<'a>, ResultCode>, + >, + ) -> Result<SyncProgressFromCheckpoint, ResultCode> { + let mut buckets = BTreeMap::<String, BucketProgress>::new(); + let mut needs_reset = false; + for bucket in checkpoint.buckets.values() { + buckets.insert( + bucket.bucket.clone(), + BucketProgress { + priority: bucket.priority, + target_count: bucket.count.unwrap_or(0), + // Will be filled out later by iterating local_progress + at_last: 0, + since_last: 0, + }, + ); + } + + while let Some(row) = local_progress.next() { + let row = match row { + Ok(row) => row, + Err(e) => return Err(*e), + }; + + let Some(progress) = buckets.get_mut(row.bucket) else { + continue; + }; + + progress.at_last = row.count_at_last; + progress.since_last = row.count_since_last; + + if progress.target_count < row.count_at_last + row.count_since_last { + needs_reset = true; + // Either due to a defrag / sync rule deploy or a compaction operation, the size + // of the bucket shrank so much that the local ops exceed the ops in the updated + // bucket. We can't possibly report progress in this case (it would overshoot 100%). + for (_, progress) in &mut buckets { + progress.at_last = 0; + progress.since_last = 0; + } + break; + } + } + + Ok(SyncProgressFromCheckpoint { + progress: Self { buckets }, + needs_counter_reset: needs_reset, + }) + } + + pub fn increment_download_count(&mut self, line: &DataLine) { + if let Some(info) = self.buckets.get_mut(&*line.bucket) { + info.since_last += line.data.len() as i64 + } + } +} diff --git a/crates/core/src/sync_local.rs b/crates/core/src/sync_local.rs index 40ddcd5..f884e88 100644 --- a/crates/core/src/sync_local.rs +++ b/crates/core/src/sync_local.rs @@ -4,8 +4,8 @@ use alloc::string::String; use alloc::vec::Vec; use serde::Deserialize; -use crate::bucket_priority::BucketPriority; use crate::error::{PSResult, SQLiteError}; +use crate::sync::BucketPriority; use sqlite_nostd::{self as sqlite, Destructor, ManagedStmt, Value}; use sqlite_nostd::{ColumnType, Connection, ResultCode}; @@ -13,30 +13,29 @@ use crate::ext::SafeManagedStmt; use crate::util::{internal_table_name, quote_internal_name}; pub fn sync_local<V: Value>(db: *mut sqlite::sqlite3, data: &V) -> Result<i64, SQLiteError> { - let mut operation = SyncOperation::new(db, data)?; + let mut operation = SyncOperation::from_args(db, data)?; operation.apply() } -struct PartialSyncOperation<'a> { +pub struct PartialSyncOperation<'a> { /// The lowest priority part of the partial sync operation. - priority: BucketPriority, + pub priority: BucketPriority, /// The JSON-encoded arguments passed by the client SDK. This includes the priority and a list /// of bucket names in that (and higher) priorities.
- args: &'a str, + pub args: &'a str, } -struct SyncOperation<'a> { +pub struct SyncOperation<'a> { db: *mut sqlite::sqlite3, data_tables: BTreeSet, partial: Option>, } impl<'a> SyncOperation<'a> { - fn new(db: *mut sqlite::sqlite3, data: &'a V) -> Result { - return Ok(Self { - db: db, - data_tables: BTreeSet::new(), - partial: match data.value_type() { + fn from_args(db: *mut sqlite::sqlite3, data: &'a V) -> Result { + Ok(Self::new( + db, + match data.value_type() { ColumnType::Text => { let text = data.text(); if text.len() > 0 { @@ -58,7 +57,15 @@ impl<'a> SyncOperation<'a> { } _ => None, }, - }); + )) + } + + pub fn new(db: *mut sqlite::sqlite3, partial: Option>) -> Self { + Self { + db, + data_tables: BTreeSet::new(), + partial, + } } fn can_apply_sync_changes(&self) -> Result { @@ -96,18 +103,27 @@ impl<'a> SyncOperation<'a> { Ok(true) } - fn apply(&mut self) -> Result { + pub fn apply(&mut self) -> Result { if !self.can_apply_sync_changes()? { return Ok(0); } self.collect_tables()?; let statement = self.collect_full_operations()?; - // TODO: cache statements + + // We cache the last insert and delete statements for each row + let mut last_insert_table: Option = None; + let mut last_insert_statement: Option = None; + + let mut last_delete_table: Option = None; + let mut last_delete_statement: Option = None; + + let mut untyped_delete_statement: Option = None; + let mut untyped_insert_statement: Option = None; + while statement.step().into_db_result(self.db)? == ResultCode::ROW { let type_name = statement.column_text(0)?; let id = statement.column_text(1)?; - let buckets = statement.column_int(3); let data = statement.column_text(2); let table_name = internal_table_name(type_name); @@ -115,42 +131,74 @@ impl<'a> SyncOperation<'a> { if self.data_tables.contains(&table_name) { let quoted = quote_internal_name(type_name, false); - if buckets == 0 { + // is_err() is essentially a NULL check here. + // NULL data means no PUT operations found, so we delete the row. + if data.is_err() { // DELETE - let delete_statement = self - .db - .prepare_v2(&format!("DELETE FROM {} WHERE id = ?", quoted)) - .into_db_result(self.db)?; + if last_delete_table.as_deref() != Some("ed) { + // Prepare statement when the table changed + last_delete_statement = Some( + self.db + .prepare_v2(&format!("DELETE FROM {} WHERE id = ?", quoted)) + .into_db_result(self.db)?, + ); + last_delete_table = Some(quoted.clone()); + } + let delete_statement = last_delete_statement.as_mut().unwrap(); + + delete_statement.reset()?; delete_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; delete_statement.exec()?; } else { // INSERT/UPDATE - let insert_statement = self - .db - .prepare_v2(&format!("REPLACE INTO {}(id, data) VALUES(?, ?)", quoted)) - .into_db_result(self.db)?; + if last_insert_table.as_deref() != Some("ed) { + // Prepare statement when the table changed + last_insert_statement = Some( + self.db + .prepare_v2(&format!( + "REPLACE INTO {}(id, data) VALUES(?, ?)", + quoted + )) + .into_db_result(self.db)?, + ); + last_insert_table = Some(quoted.clone()); + } + let insert_statement = last_insert_statement.as_mut().unwrap(); + insert_statement.reset()?; insert_statement.bind_text(1, id, sqlite::Destructor::STATIC)?; insert_statement.bind_text(2, data?, sqlite::Destructor::STATIC)?; insert_statement.exec()?; } } else { - if buckets == 0 { + if data.is_err() { // DELETE - // language=SQLite - let delete_statement = self - .db - .prepare_v2("DELETE FROM ps_untyped WHERE type = ? 
AND id = ?") - .into_db_result(self.db)?; + if untyped_delete_statement.is_none() { + // Prepare statement on first use + untyped_delete_statement = Some( + self.db + .prepare_v2("DELETE FROM ps_untyped WHERE type = ? AND id = ?") + .into_db_result(self.db)?, + ); + } + let delete_statement = untyped_delete_statement.as_mut().unwrap(); + delete_statement.reset()?; delete_statement.bind_text(1, type_name, sqlite::Destructor::STATIC)?; delete_statement.bind_text(2, id, sqlite::Destructor::STATIC)?; delete_statement.exec()?; } else { // INSERT/UPDATE - // language=SQLite - let insert_statement = self - .db - .prepare_v2("REPLACE INTO ps_untyped(type, id, data) VALUES(?, ?, ?)") - .into_db_result(self.db)?; + if untyped_insert_statement.is_none() { + // Prepare statement on first use + untyped_insert_statement = Some( + self.db + .prepare_v2( + "REPLACE INTO ps_untyped(type, id, data) VALUES(?, ?, ?)", + ) + .into_db_result(self.db)?, + ); + } + let insert_statement = untyped_insert_statement.as_mut().unwrap(); + insert_statement.reset()?; insert_statement.bind_text(1, type_name, sqlite::Destructor::STATIC)?; insert_statement.bind_text(2, id, sqlite::Destructor::STATIC)?; insert_statement.bind_text(3, data?, sqlite::Destructor::STATIC)?; @@ -185,32 +233,29 @@ impl<'a> SyncOperation<'a> { Ok(match &self.partial { None => { // Complete sync + // See dart/test/sync_local_performance_test.dart for an annotated version of this query. self.db .prepare_v2( "\ --- 1. Filter oplog by the ops added but not applied yet (oplog b). --- SELECT DISTINCT / UNION is important for cases with many duplicate ids. WITH updated_rows AS ( - SELECT DISTINCT b.row_type, b.row_id FROM ps_buckets AS buckets - CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id - AND (b.op_id > buckets.last_applied_op) - UNION SELECT row_type, row_id FROM ps_updated_rows + SELECT b.row_type, b.row_id FROM ps_buckets AS buckets + CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id + AND (b.op_id > buckets.last_applied_op) + UNION ALL SELECT row_type, row_id FROM ps_updated_rows ) --- 3. Group the objects from different buckets together into a single one (ops). -SELECT b.row_type as type, - b.row_id as id, - r.data as data, - count(r.bucket) as buckets, - /* max() affects which row is used for 'data' */ - max(r.op_id) as op_id --- 2. Find *all* current ops over different buckets for those objects (oplog r). -FROM updated_rows b - LEFT OUTER JOIN ps_oplog AS r - ON r.row_type = b.row_type - AND r.row_id = b.row_id --- Group for (3) -GROUP BY b.row_type, b.row_id", +SELECT + b.row_type, + b.row_id, + ( + SELECT iif(max(r.op_id), r.data, null) + FROM ps_oplog r + WHERE r.row_type = b.row_type + AND r.row_id = b.row_id + + ) as data + FROM updated_rows b + GROUP BY b.row_type, b.row_id;", ) .into_db_result(self.db)? } @@ -220,33 +265,38 @@ GROUP BY b.row_type, b.row_id", .prepare_v2( "\ -- 1. Filter oplog by the ops added but not applied yet (oplog b). --- SELECT DISTINCT / UNION is important for cases with many duplicate ids. +-- We do not do any DISTINCT operation here, since that introduces a temp b-tree. +-- We filter out duplicates using the GROUP BY below. 
WITH involved_buckets (id) AS MATERIALIZED ( SELECT id FROM ps_buckets WHERE ?1 IS NULL OR name IN (SELECT value FROM json_each(json_extract(?1, '$.buckets'))) ), updated_rows AS ( - SELECT DISTINCT FALSE as local, b.row_type, b.row_id FROM ps_buckets AS buckets - CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id AND (b.op_id > buckets.last_applied_op) - WHERE buckets.id IN (SELECT id FROM involved_buckets) + SELECT b.row_type, b.row_id FROM ps_buckets AS buckets + CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id + AND (b.op_id > buckets.last_applied_op) + WHERE buckets.id IN (SELECT id FROM involved_buckets) ) --- 3. Group the objects from different buckets together into a single one (ops). -SELECT b.row_type as type, - b.row_id as id, - r.data as data, - count(r.bucket) as buckets, - /* max() affects which row is used for 'data' */ - max(r.op_id) as op_id -- 2. Find *all* current ops over different buckets for those objects (oplog r). -FROM updated_rows b - LEFT OUTER JOIN ps_oplog AS r - ON r.row_type = b.row_type - AND r.row_id = b.row_id - AND r.bucket IN (SELECT id FROM involved_buckets) --- Group for (3) -GROUP BY b.row_type, b.row_id", +SELECT + b.row_type, + b.row_id, + ( + -- 3. For each unique row, select the data from the latest oplog entry. + -- The max(r.op_id) clause is used to select the latest oplog entry. + -- The iif is to avoid the max(r.op_id) column ending up in the results. + SELECT iif(max(r.op_id), r.data, null) + FROM ps_oplog r + WHERE r.row_type = b.row_type + AND r.row_id = b.row_id + AND r.bucket IN (SELECT id FROM involved_buckets) + + ) as data + FROM updated_rows b + -- Group for (2) + GROUP BY b.row_type, b.row_id;", ) .into_db_result(self.db)?; stmt.bind_text(1, partial.args, Destructor::STATIC)?; diff --git a/crates/core/src/sync_types.rs b/crates/core/src/sync_types.rs deleted file mode 100644 index 060dd25..0000000 --- a/crates/core/src/sync_types.rs +++ /dev/null @@ -1,22 +0,0 @@ -use alloc::string::String; -use alloc::vec::Vec; -use serde::{Deserialize, Serialize}; - -use crate::util::{deserialize_optional_string_to_i64, deserialize_string_to_i64}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct Checkpoint { - #[serde(deserialize_with = "deserialize_string_to_i64")] - pub last_op_id: i64, - #[serde(default)] - #[serde(deserialize_with = "deserialize_optional_string_to_i64")] - pub write_checkpoint: Option, - pub buckets: Vec, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct BucketChecksum { - pub bucket: String, - pub checksum: i32, - pub priority: Option, -} diff --git a/crates/core/src/util.rs b/crates/core/src/util.rs index 2e50951..a9e0842 100644 --- a/crates/core/src/util.rs +++ b/crates/core/src/util.rs @@ -3,11 +3,9 @@ extern crate alloc; use alloc::format; use alloc::string::String; -use serde::Deserialize; -use serde_json as json; - #[cfg(not(feature = "getrandom"))] use crate::sqlite; +use serde::de::Visitor; use uuid::Uuid; @@ -46,25 +44,56 @@ pub fn deserialize_string_to_i64<'de, D>(deserializer: D) -> Result, { - let value = json::Value::deserialize(deserializer)?; + struct ValueVisitor; + + impl<'de> Visitor<'de> for ValueVisitor { + type Value = i64; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str("a string representation of a number") + } - match value { - json::Value::String(s) => s.parse::().map_err(serde::de::Error::custom), - _ => Err(serde::de::Error::custom("Expected a string.")), + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + 
{ + v.parse::().map_err(serde::de::Error::custom) + } } + + // Using a custom visitor here to avoid an intermediate string allocation + deserializer.deserialize_str(ValueVisitor) } pub fn deserialize_optional_string_to_i64<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { - let value = json::Value::deserialize(deserializer)?; - - match value { - json::Value::Null => Ok(None), - json::Value::String(s) => s.parse::().map(Some).map_err(serde::de::Error::custom), - _ => Err(serde::de::Error::custom("Expected a string or null.")), + struct ValueVisitor; + + impl<'de> Visitor<'de> for ValueVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str("a string or null") + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Some(deserialize_string_to_i64(deserializer)?)) + } } + + deserializer.deserialize_option(ValueVisitor) } // Use getrandom crate to generate UUID. diff --git a/crates/core/src/uuid.rs b/crates/core/src/uuid.rs index db617f9..82d9046 100644 --- a/crates/core/src/uuid.rs +++ b/crates/core/src/uuid.rs @@ -1,6 +1,5 @@ extern crate alloc; -use alloc::format; use alloc::string::String; use alloc::string::ToString; use core::ffi::c_int; diff --git a/crates/core/src/views.rs b/crates/core/src/views.rs index dea0b2b..03cbdd8 100644 --- a/crates/core/src/views.rs +++ b/crates/core/src/views.rs @@ -6,49 +6,41 @@ use alloc::string::String; use alloc::vec::Vec; use core::ffi::c_int; use core::fmt::Write; -use streaming_iterator::StreamingIterator; use sqlite::{Connection, Context, ResultCode, Value}; use sqlite_nostd::{self as sqlite}; use crate::create_sqlite_text_fn; use crate::error::SQLiteError; -use crate::schema::{ColumnInfo, ColumnNameAndTypeStatement, DiffIncludeOld, TableInfo}; +use crate::schema::{DiffIncludeOld, Table}; use crate::util::*; fn powersync_view_sql_impl( - ctx: *mut sqlite::context, + _ctx: *mut sqlite::context, args: &[*mut sqlite::value], ) -> Result { - let db = ctx.db_handle(); - let table = args[0].text(); - let table_info = TableInfo::parse_from(db, table)?; + let table_info = Table::from_json(args[0].text())?; let name = &table_info.name; - let view_name = &table_info.view_name; + let view_name = &table_info.view_name(); let local_only = table_info.flags.local_only(); let include_metadata = table_info.flags.include_metadata(); let quoted_name = quote_identifier(view_name); let internal_name = quote_internal_name(name, local_only); - let mut columns = ColumnNameAndTypeStatement::new(db, table)?; - let mut iter = columns.streaming_iter(); - let mut column_names_quoted: Vec = alloc::vec![]; let mut column_values: Vec = alloc::vec![]; column_names_quoted.push(quote_identifier("id")); column_values.push(String::from("id")); - while let Some(row) = iter.next() { - let ColumnInfo { name, type_name } = row.clone()?; - column_names_quoted.push(quote_identifier(name)); + for column in &table_info.columns { + column_names_quoted.push(quote_identifier(&column.name)); - let foo = format!( + column_values.push(format!( "CAST(json_extract(data, {:}) as {:})", - quote_json_path(name), - type_name - ); - column_values.push(foo); + quote_json_path(&column.name), + &column.type_name + )); } if include_metadata { @@ -77,14 +69,13 @@ create_sqlite_text_fn!( ); fn powersync_trigger_delete_sql_impl( - ctx: *mut sqlite::context, + _ctx: *mut 
sqlite::context, args: &[*mut sqlite::value], ) -> Result { - let table = args[0].text(); - let table_info = TableInfo::parse_from(ctx.db_handle(), table)?; + let table_info = Table::from_json(args[0].text())?; let name = &table_info.name; - let view_name = &table_info.view_name; + let view_name = &table_info.view_name(); let local_only = table_info.flags.local_only(); let insert_only = table_info.flags.insert_only(); @@ -93,23 +84,14 @@ fn powersync_trigger_delete_sql_impl( let trigger_name = quote_identifier_prefixed("ps_view_delete_", view_name); let type_string = quote_string(name); - let db = ctx.db_handle(); - let old_fragment: Cow<'static, str> = match table_info.diff_include_old { + let old_fragment: Cow<'static, str> = match &table_info.diff_include_old { Some(include_old) => { - let mut columns = ColumnNameAndTypeStatement::new(db, table)?; - let json = match include_old { DiffIncludeOld::OnlyForColumns { columns } => { - let mut iterator = columns.iter(); - let mut columns = - streaming_iterator::from_fn(|| -> Option> { - Some(Ok(iterator.next()?.as_str())) - }); - - json_object_fragment("OLD", &mut columns) + json_object_fragment("OLD", &mut columns.iter().map(|c| c.as_str())) } DiffIncludeOld::ForAllColumns => { - json_object_fragment("OLD", &mut columns.names_iter()) + json_object_fragment("OLD", &mut table_info.column_names()) } }?; @@ -179,15 +161,13 @@ create_sqlite_text_fn!( ); fn powersync_trigger_insert_sql_impl( - ctx: *mut sqlite::context, + _ctx: *mut sqlite::context, args: &[*mut sqlite::value], ) -> Result { - let table = args[0].text(); - - let table_info = TableInfo::parse_from(ctx.db_handle(), table)?; + let table_info = Table::from_json(args[0].text())?; let name = &table_info.name; - let view_name = &table_info.view_name; + let view_name = &table_info.view_name(); let local_only = table_info.flags.local_only(); let insert_only = table_info.flags.insert_only(); @@ -196,10 +176,7 @@ fn powersync_trigger_insert_sql_impl( let trigger_name = quote_identifier_prefixed("ps_view_insert_", view_name); let type_string = quote_string(name); - let local_db = ctx.db_handle(); - - let mut columns = ColumnNameAndTypeStatement::new(local_db, table)?; - let json_fragment = json_object_fragment("NEW", &mut columns.names_iter())?; + let json_fragment = json_object_fragment("NEW", &mut table_info.column_names())?; let metadata_fragment = if table_info.flags.include_metadata() { ", 'metadata', NEW._metadata" @@ -258,15 +235,13 @@ create_sqlite_text_fn!( ); fn powersync_trigger_update_sql_impl( - ctx: *mut sqlite::context, + _ctx: *mut sqlite::context, args: &[*mut sqlite::value], ) -> Result { - let table = args[0].text(); - - let table_info = TableInfo::parse_from(ctx.db_handle(), table)?; + let table_info = Table::from_json(args[0].text())?; let name = &table_info.name; - let view_name = &table_info.view_name; + let view_name = &table_info.view_name(); let insert_only = table_info.flags.insert_only(); let local_only = table_info.flags.local_only(); @@ -275,23 +250,16 @@ fn powersync_trigger_update_sql_impl( let trigger_name = quote_identifier_prefixed("ps_view_update_", view_name); let type_string = quote_string(name); - let db = ctx.db_handle(); - let mut columns = ColumnNameAndTypeStatement::new(db, table)?; - let json_fragment_new = json_object_fragment("NEW", &mut columns.names_iter())?; - let json_fragment_old = json_object_fragment("OLD", &mut columns.names_iter())?; + let json_fragment_new = json_object_fragment("NEW", &mut table_info.column_names())?; + let 
json_fragment_old = json_object_fragment("OLD", &mut table_info.column_names())?; let mut old_values_fragment = match &table_info.diff_include_old { None => None, Some(DiffIncludeOld::ForAllColumns) => Some(json_fragment_old.clone()), - Some(DiffIncludeOld::OnlyForColumns { columns }) => { - let mut iterator = columns.iter(); - let mut columns = - streaming_iterator::from_fn(|| -> Option> { - Some(Ok(iterator.next()?.as_str())) - }); - - Some(json_object_fragment("OLD", &mut columns)?) - } + Some(DiffIncludeOld::OnlyForColumns { columns }) => Some(json_object_fragment( + "OLD", + &mut columns.iter().map(|c| c.as_str()), + )?), }; if table_info.flags.include_old_only_when_changed() { @@ -301,15 +269,9 @@ fn powersync_trigger_update_sql_impl( let filtered_new_fragment = match &table_info.diff_include_old { // When include_old_only_when_changed is combined with a column filter, make sure we // only include the powersync_diff of columns matched by the filter. - Some(DiffIncludeOld::OnlyForColumns { columns }) => { - let mut iterator = columns.iter(); - let mut columns = - streaming_iterator::from_fn(|| -> Option> { - Some(Ok(iterator.next()?.as_str())) - }); - - Cow::Owned(json_object_fragment("NEW", &mut columns)?) - } + Some(DiffIncludeOld::OnlyForColumns { columns }) => Cow::Owned( + json_object_fragment("NEW", &mut columns.iter().map(|c| c.as_str()))?, + ), _ => Cow::Borrowed(json_fragment_new.as_str()), }; @@ -444,7 +406,7 @@ pub fn register(db: *mut sqlite::sqlite3) -> Result<(), ResultCode> { /// Example output with prefix "NEW": "json_object('id', NEW.id, 'name', NEW.name, 'age', NEW.age)". fn json_object_fragment<'a>( prefix: &str, - name_results: &mut dyn StreamingIterator>, + name_results: &mut dyn Iterator, ) -> Result { // floor(SQLITE_MAX_FUNCTION_ARG / 2). // To keep databases portable, we use the default limit of 100 args for this, @@ -452,9 +414,7 @@ fn json_object_fragment<'a>( const MAX_ARG_COUNT: usize = 50; let mut column_names_quoted: Vec = alloc::vec![]; - while let Some(row) = name_results.next() { - let name = (*row)?; - + while let Some(name) = name_results.next() { let quoted: String = format!( "{:}, {:}.{:}", quote_string(name), diff --git a/crates/loadable/src/lib.rs b/crates/loadable/src/lib.rs index c6ca649..9e5a9de 100644 --- a/crates/loadable/src/lib.rs +++ b/crates/loadable/src/lib.rs @@ -3,7 +3,6 @@ #![feature(core_intrinsics)] #![allow(internal_features)] #![feature(lang_items)] -#![feature(error_in_core)] extern crate alloc; diff --git a/dart/benchmark/apply_lines.dart b/dart/benchmark/apply_lines.dart new file mode 100644 index 0000000..ee0654b --- /dev/null +++ b/dart/benchmark/apply_lines.dart @@ -0,0 +1,48 @@ +import 'dart:io'; +import 'dart:typed_data'; + +import '../test/utils/native_test_utils.dart'; + +/// Usage: dart run benchmark/apply_lines.dart path/to/lines.bin +/// +/// This creates a new in-memory database and applies concatenated BSON sync +/// lines from a file. 
+void main(List args) { + if (args.length != 1) { + throw 'Usage: dart run benchmark/apply_lines.dart path/to/lines.bin'; + } + + final [path] = args; + final file = File(path).openSync(); + final db = openTestDatabase(); + + db + ..execute('select powersync_init()') + ..execute('select powersync_control(?, null)', ['start']); + + final stopwatch = Stopwatch()..start(); + + final lengthBuffer = Uint8List(4); + while (file.positionSync() < file.lengthSync()) { + // BSON document: + final bytesRead = file.readIntoSync(lengthBuffer); + if (bytesRead != 4) { + throw 'short read, expected length'; + } + final length = lengthBuffer.buffer.asByteData().getInt32(0, Endian.little); + file.setPositionSync(file.positionSync() - 4); + + final syncLineBson = file.readSync(length); + if (syncLineBson.length != length) { + throw 'short read for bson document'; + } + + db + ..execute('BEGIN') + ..execute('SELECT powersync_control(?, ?)', ['line_binary', syncLineBson]) + ..execute('COMMIT;'); + } + + stopwatch.stop(); + print('Applying $path took ${stopwatch.elapsed}'); +} diff --git a/dart/pubspec.lock b/dart/pubspec.lock index 915c28e..5000e7e 100644 --- a/dart/pubspec.lock +++ b/dart/pubspec.lock @@ -41,6 +41,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.1.2" + bson: + dependency: "direct main" + description: + name: bson + sha256: "9b761248a3494fea594aecf5d6f369b5f04d7b082aa2b8c06579ade77f1a7e47" + url: "https://pub.dev" + source: hosted + version: "5.0.6" cli_config: dependency: transitive description: @@ -66,7 +74,7 @@ packages: source: hosted version: "1.19.1" convert: - dependency: transitive + dependency: "direct dev" description: name: convert sha256: b30acd5944035672bc15c6b7a8b47d773e41e2f17de064350988c5d02adb1c68 @@ -89,6 +97,14 @@ packages: url: "https://pub.dev" source: hosted version: "3.0.6" + decimal: + dependency: transitive + description: + name: decimal + sha256: "28239b8b929c1bd8618702e6dbc96e2618cf99770bbe9cb040d6cf56a11e4ec3" + url: "https://pub.dev" + source: hosted + version: "3.2.1" fake_async: dependency: "direct dev" description: @@ -113,6 +129,14 @@ packages: url: "https://pub.dev" source: hosted version: "7.0.1" + fixnum: + dependency: transitive + description: + name: fixnum + sha256: b6dc7065e46c974bc7c5f143080a6764ec7a4be6da1285ececdc37be96de53be + url: "https://pub.dev" + source: hosted + version: "1.1.1" frontend_server_client: dependency: transitive description: @@ -145,6 +169,14 @@ packages: url: "https://pub.dev" source: hosted version: "4.1.2" + intl: + dependency: transitive + description: + name: intl + sha256: "3df61194eb431efc39c4ceba583b95633a403f46c9fd341e550ce0bfa50e9aa5" + url: "https://pub.dev" + source: hosted + version: "0.20.2" io: dependency: transitive description: @@ -178,7 +210,7 @@ packages: source: hosted version: "0.12.17" meta: - dependency: transitive + dependency: "direct dev" description: name: meta sha256: e3641ec5d63ebf0d9b41bd43201a66e3fc79a65db5f61fc181f04cd27aab950c @@ -209,6 +241,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.2.0" + packages_extensions: + dependency: transitive + description: + name: packages_extensions + sha256: "1fb328695a9828c80d275ce1650a2bb5947690070de082dfa1dfac7429378daf" + url: "https://pub.dev" + source: hosted + version: "0.1.1" path: dependency: transitive description: @@ -225,6 +265,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.5.1" + power_extensions: + dependency: transitive + description: + name: power_extensions + sha256: 
ad0e8b2420090d996fe8b7fd32cdf02b9b924b6d4fc0fb0b559ff6aa5e24d5b0 + url: "https://pub.dev" + source: hosted + version: "0.2.3" pub_semver: dependency: transitive description: @@ -233,6 +281,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.2.0" + rational: + dependency: transitive + description: + name: rational + sha256: cb808fb6f1a839e6fc5f7d8cb3b0a10e1db48b3be102de73938c627f0b636336 + url: "https://pub.dev" + source: hosted + version: "2.2.3" shelf: dependency: transitive description: @@ -289,14 +345,22 @@ packages: url: "https://pub.dev" source: hosted version: "1.10.1" + sprintf: + dependency: transitive + description: + name: sprintf + sha256: "1fc9ffe69d4df602376b52949af107d8f5703b77cda567c4d7d86a0693120f23" + url: "https://pub.dev" + source: hosted + version: "7.0.0" sqlite3: dependency: "direct main" description: name: sqlite3 - sha256: "310af39c40dd0bb2058538333c9d9840a2725ae0b9f77e4fd09ad6696aa8f66e" + sha256: c0503c69b44d5714e6abbf4c1f51a3c3cc42b75ce785f44404765e4635481d38 url: "https://pub.dev" source: hosted - version: "2.7.5" + version: "2.7.6" sqlite3_test: dependency: "direct dev" description: @@ -369,6 +433,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.4.0" + uuid: + dependency: transitive + description: + name: uuid + sha256: a5be9ef6618a7ac1e964353ef476418026db906c4facdedaa299b7a2e71690ff + url: "https://pub.dev" + source: hosted + version: "4.5.1" vm_service: dependency: transitive description: diff --git a/dart/pubspec.yaml b/dart/pubspec.yaml index 1867319..40b4648 100644 --- a/dart/pubspec.yaml +++ b/dart/pubspec.yaml @@ -5,9 +5,13 @@ description: Tests for powersync-sqlite-core environment: sdk: ^3.4.0 dependencies: - sqlite3: ^2.4.5 + sqlite3: ^2.7.6 + bson: ^5.0.5 + dev_dependencies: test: ^1.25.0 file: ^7.0.1 sqlite3_test: ^0.1.1 fake_async: ^1.3.3 + convert: ^3.1.2 + meta: ^1.16.0 diff --git a/dart/test/goldens/simple_iteration.json b/dart/test/goldens/simple_iteration.json new file mode 100644 index 0000000..ad3f358 --- /dev/null +++ b/dart/test/goldens/simple_iteration.json @@ -0,0 +1,168 @@ +[ + { + "operation": "start", + "data": null, + "output": [ + { + "UpdateSyncStatus": { + "status": { + "connected": false, + "connecting": true, + "priority_status": [], + "downloading": null + } + } + }, + { + "EstablishSyncStream": { + "request": { + "buckets": [], + "include_checksum": true, + "raw_data": true, + "binary_data": true, + "client_id": "test-test-test-test", + "parameters": null + } + } + } + ] + }, + { + "operation": "line_text", + "data": { + "checkpoint": { + "last_op_id": "1", + "write_checkpoint": null, + "buckets": [ + { + "bucket": "a", + "checksum": 0, + "priority": 3, + "count": 1 + } + ] + } + }, + "output": [ + { + "UpdateSyncStatus": { + "status": { + "connected": true, + "connecting": false, + "priority_status": [], + "downloading": { + "buckets": { + "a": { + "priority": 3, + "at_last": 0, + "since_last": 0, + "target_count": 1 + } + } + } + } + } + } + ] + }, + { + "operation": "line_text", + "data": { + "token_expires_in": 60 + }, + "output": [] + }, + { + "operation": "line_text", + "data": { + "data": { + "bucket": "a", + "has_more": false, + "after": null, + "next_after": null, + "data": [ + { + "op_id": "1", + "op": "PUT", + "object_type": "items", + "object_id": "1", + "checksum": 0, + "data": "{\"col\":\"hi\"}" + } + ] + } + }, + "output": [ + { + "UpdateSyncStatus": { + "status": { + "connected": true, + "connecting": false, + "priority_status": [], + "downloading": { + "buckets": { + "a": { + 
"priority": 3, + "at_last": 0, + "since_last": 1, + "target_count": 1 + } + } + } + } + } + } + ] + }, + { + "operation": "line_text", + "data": { + "checkpoint_complete": { + "last_op_id": "1" + } + }, + "output": [ + { + "LogLine": { + "severity": "DEBUG", + "line": "Validated and applied checkpoint" + } + }, + { + "FlushFileSystem": {} + }, + { + "DidCompleteSync": {} + }, + { + "UpdateSyncStatus": { + "status": { + "connected": true, + "connecting": false, + "priority_status": [ + { + "priority": 2147483647, + "last_synced_at": 1740823200, + "has_synced": true + } + ], + "downloading": null + } + } + } + ] + }, + { + "operation": "line_text", + "data": { + "token_expires_in": 10 + }, + "output": [ + { + "FetchCredentials": { + "did_expire": false + } + } + ] + } +] \ No newline at end of file diff --git a/dart/test/goldens/starting_stream.json b/dart/test/goldens/starting_stream.json new file mode 100644 index 0000000..96e505c --- /dev/null +++ b/dart/test/goldens/starting_stream.json @@ -0,0 +1,36 @@ +[ + { + "operation": "start", + "data": { + "parameters": { + "foo": "bar" + } + }, + "output": [ + { + "UpdateSyncStatus": { + "status": { + "connected": false, + "connecting": true, + "priority_status": [], + "downloading": null + } + } + }, + { + "EstablishSyncStream": { + "request": { + "buckets": [], + "include_checksum": true, + "raw_data": true, + "binary_data": true, + "client_id": "test-test-test-test", + "parameters": { + "foo": "bar" + } + } + } + } + ] + } +] \ No newline at end of file diff --git a/dart/test/js_key_encoding_test.dart b/dart/test/js_key_encoding_test.dart new file mode 100644 index 0000000..dd86e06 --- /dev/null +++ b/dart/test/js_key_encoding_test.dart @@ -0,0 +1,79 @@ +import 'dart:convert'; + +import 'package:file/local.dart'; +import 'package:sqlite3/common.dart'; +import 'package:sqlite3/sqlite3.dart'; +import 'package:sqlite3_test/sqlite3_test.dart'; +import 'package:test/test.dart'; + +import 'utils/native_test_utils.dart'; + +void main() { + // Needs an unique name per test file to avoid concurrency issues + final vfs = TestSqliteFileSystem( + fs: const LocalFileSystem(), name: 'js-key-encoding-test-vfs'); + late CommonDatabase db; + + setUpAll(() { + loadExtension(); + sqlite3.registerVirtualFileSystem(vfs, makeDefault: false); + }); + tearDownAll(() => sqlite3.unregisterVirtualFileSystem(vfs)); + + setUp(() async { + db = openTestDatabase(vfs: vfs) + ..select('select powersync_init();') + ..select('select powersync_replace_schema(?)', [json.encode(_schema)]); + }); + + tearDown(() { + db.dispose(); + }); + + test('can fix JS key encoding', () { + db.execute('insert into powersync_operations (op, data) VALUES (?, ?);', [ + 'save', + json.encode({ + 'buckets': [ + { + 'bucket': 'a', + 'data': [ + { + 'op_id': '1', + 'op': 'PUT', + 'object_type': 'items', + 'object_id': '1', + 'subkey': json.encode('subkey'), + 'checksum': 0, + 'data': json.encode({'col': 'a'}), + } + ], + } + ], + }) + ]); + + db.execute('INSERT INTO powersync_operations(op, data) VALUES (?, ?)', + ['sync_local', null]); + var [row] = db.select('select * from ps_oplog'); + expect(row['key'], 'items/1/"subkey"'); + + // Apply migration + db.execute( + 'UPDATE ps_oplog SET key = powersync_remove_duplicate_key_encoding(key);'); + + [row] = db.select('select * from ps_oplog'); + expect(row['key'], 'items/1/subkey'); + }); +} + +const _schema = { + 'tables': [ + { + 'name': 'items', + 'columns': [ + {'name': 'col', 'type': 'text'} + ], + } + ] +}; diff --git 
a/dart/test/legacy_sync_test.dart b/dart/test/legacy_sync_test.dart new file mode 100644 index 0000000..e67f7bb --- /dev/null +++ b/dart/test/legacy_sync_test.dart @@ -0,0 +1,341 @@ +import 'dart:convert'; + +import 'package:fake_async/fake_async.dart'; +import 'package:file/local.dart'; +import 'package:sqlite3/common.dart'; +import 'package:sqlite3/sqlite3.dart'; +import 'package:sqlite3_test/sqlite3_test.dart'; +import 'package:test/test.dart'; + +import 'utils/native_test_utils.dart'; + +/// Tests that the older sync interfaces requiring clients to decode and handle +/// sync lines still work. +void main() { + final vfs = TestSqliteFileSystem( + fs: const LocalFileSystem(), name: 'legacy-sync-test'); + + setUpAll(() { + loadExtension(); + sqlite3.registerVirtualFileSystem(vfs, makeDefault: false); + }); + tearDownAll(() => sqlite3.unregisterVirtualFileSystem(vfs)); + + group('sync tests', () { + late CommonDatabase db; + + setUp(() async { + db = openTestDatabase(vfs: vfs) + ..select('select powersync_init();') + ..select('select powersync_replace_schema(?)', [json.encode(_schema)]); + }); + + tearDown(() { + db.dispose(); + }); + + void pushSyncData( + String bucket, + String opId, + String rowId, + Object op, + Object? data, { + Object? descriptions = _bucketDescriptions, + }) { + final encoded = json.encode({ + 'buckets': [ + { + 'bucket': bucket, + 'data': [ + { + 'op_id': opId, + 'op': op, + 'object_type': 'items', + 'object_id': rowId, + 'checksum': 0, + 'data': json.encode(data), + } + ], + } + ], + if (descriptions != null) 'descriptions': descriptions, + }); + + db.execute('insert into powersync_operations (op, data) VALUES (?, ?);', + ['save', encoded]); + } + + bool pushCheckpointComplete( + String lastOpId, String? writeCheckpoint, List checksums, + {int? priority}) { + final [row] = db.select('select powersync_validate_checkpoint(?) as r;', [ + json.encode({ + 'last_op_id': lastOpId, + 'write_checkpoint': writeCheckpoint, + 'buckets': [ + for (final cs in checksums.cast>()) + if (priority == null || cs['priority'] <= priority) cs + ], + 'priority': priority, + }) + ]); + + final decoded = json.decode(row['r']); + if (decoded['valid'] != true) { + fail(row['r']); + } + + db.execute( + 'UPDATE ps_buckets SET last_op = ? WHERE name IN (SELECT json_each.value FROM json_each(?))', + [ + lastOpId, + json.encode(checksums.map((e) => (e as Map)['bucket']).toList()) + ], + ); + + db.execute('INSERT INTO powersync_operations(op, data) VALUES (?, ?)', [ + 'sync_local', + priority != null + ? 
jsonEncode({ + 'priority': priority, + 'buckets': [ + for (final cs in checksums.cast>()) + if (cs['priority'] <= priority) cs['bucket'] + ], + }) + : null, + ]); + return db.lastInsertRowId == 1; + } + + ResultSet fetchRows() { + return db.select('select * from items'); + } + + test('does not publish until reaching checkpoint', () { + expect(fetchRows(), isEmpty); + pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); + expect(fetchRows(), isEmpty); + + expect( + pushCheckpointComplete( + '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), + isTrue); + expect(fetchRows(), [ + {'id': 'row-0', 'col': 'hi'} + ]); + }); + + test('does not publish with pending local data', () { + expect(fetchRows(), isEmpty); + db.execute("insert into items (id, col) values ('local', 'data');"); + expect(fetchRows(), isNotEmpty); + + pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); + expect( + pushCheckpointComplete( + '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), + isFalse); + expect(fetchRows(), [ + {'id': 'local', 'col': 'data'} + ]); + }); + + test('publishes with local data for prio=0 buckets', () { + expect(fetchRows(), isEmpty); + db.execute("insert into items (id, col) values ('local', 'data');"); + expect(fetchRows(), isNotEmpty); + + pushSyncData('prio0', '1', 'row-0', 'PUT', {'col': 'hi'}); + expect( + pushCheckpointComplete( + '1', + null, + [_bucketChecksum('prio0', 0, checksum: 0)], + priority: 0, + ), + isTrue, + ); + expect(fetchRows(), [ + {'id': 'local', 'col': 'data'}, + {'id': 'row-0', 'col': 'hi'}, + ]); + }); + + test('can publish partial checkpoints under different priorities', () { + for (var i = 0; i < 4; i++) { + pushSyncData('prio$i', '1', 'row-$i', 'PUT', {'col': '$i'}); + } + expect(fetchRows(), isEmpty); + + // Simulate a partial checkpoint complete for each of the buckets. + for (var i = 0; i < 4; i++) { + expect( + pushCheckpointComplete( + '1', + null, + [ + for (var j = 0; j <= 4; j++) + _bucketChecksum( + 'prio$j', + j, + // Give buckets outside of the current priority a wrong + // checksum. They should not be validated yet. + checksum: j <= i ? 0 : 1234, + ), + ], + priority: i, + ), + isTrue, + ); + + expect(fetchRows(), [ + for (var j = 0; j <= i; j++) {'id': 'row-$j', 'col': '$j'}, + ]); + + expect(db.select('select 1 from ps_sync_state where priority = ?', [i]), + isNotEmpty); + // A sync at this priority includes all higher priorities too, so they + // should be cleared. + expect(db.select('select 1 from ps_sync_state where priority < ?', [i]), + isEmpty); + } + }); + + test('can sync multiple times', () { + fakeAsync((controller) { + for (var i = 0; i < 10; i++) { + for (var prio in const [1, 2, 3, null]) { + pushCheckpointComplete('1', null, [], priority: prio); + + // Make sure there's only a single row in last_synced_at + expect( + db.select( + "SELECT datetime(last_synced_at, 'localtime') AS last_synced_at FROM ps_sync_state WHERE priority = ?", + [prio ?? 
2147483647]), + [ + {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} + ], + ); + + if (prio == null) { + expect( + db.select( + "SELECT datetime(powersync_last_synced_at(), 'localtime') AS last_synced_at"), + [ + {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} + ], + ); + } + } + + controller.elapse(const Duration(hours: 1)); + } + }, initialTime: DateTime(2025, 3, 1, 10)); + }); + + test('clearing database clears sync status', () { + pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); + + expect( + pushCheckpointComplete( + '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), + isTrue); + expect(db.select('SELECT powersync_last_synced_at() AS r').single, + {'r': isNotNull}); + expect(db.select('SELECT priority FROM ps_sync_state').single, + {'priority': 2147483647}); + + db.execute('SELECT powersync_clear(0)'); + expect(db.select('SELECT powersync_last_synced_at() AS r').single, + {'r': isNull}); + expect(db.select('SELECT * FROM ps_sync_state'), hasLength(0)); + }); + + test('tracks download progress', () { + const bucket = 'bkt'; + void expectProgress(int atLast, int sinceLast) { + final [row] = db.select( + 'SELECT count_at_last, count_since_last FROM ps_buckets WHERE name = ?', + [bucket], + ); + final [actualAtLast, actualSinceLast] = row.values; + + expect(actualAtLast, atLast, reason: 'count_at_last mismatch'); + expect(actualSinceLast, sinceLast, reason: 'count_since_last mismatch'); + } + + pushSyncData(bucket, '1', 'row-0', 'PUT', {'col': 'hi'}); + expectProgress(0, 1); + + pushSyncData(bucket, '2', 'row-1', 'PUT', {'col': 'hi'}); + expectProgress(0, 2); + + expect( + pushCheckpointComplete( + '2', + null, + [_bucketChecksum(bucket, 1, checksum: 0)], + priority: 1, + ), + isTrue, + ); + + // Running partial or complete checkpoints should not reset stats, client + // SDKs are responsible for that. + expectProgress(0, 2); + expect(db.select('SELECT * FROM items'), isNotEmpty); + + expect( + pushCheckpointComplete( + '2', + null, + [_bucketChecksum(bucket, 1, checksum: 0)], + ), + isTrue, + ); + expectProgress(0, 2); + + db.execute(''' +UPDATE ps_buckets SET count_since_last = 0, count_at_last = ?1->name + WHERE ?1->name IS NOT NULL +''', [ + json.encode({bucket: 2}), + ]); + expectProgress(2, 0); + + // Run another iteration of this + pushSyncData(bucket, '3', 'row-3', 'PUT', {'col': 'hi'}); + expectProgress(2, 1); + db.execute(''' +UPDATE ps_buckets SET count_since_last = 0, count_at_last = ?1->name + WHERE ?1->name IS NOT NULL +''', [ + json.encode({bucket: 3}), + ]); + expectProgress(3, 0); + }); + }); +} + +Object? 
_bucketChecksum(String bucket, int prio, {int checksum = 0}) { + return {'bucket': bucket, 'priority': prio, 'checksum': checksum}; +} + +const _schema = { + 'tables': [ + { + 'name': 'items', + 'columns': [ + {'name': 'col', 'type': 'text'} + ], + } + ] +}; + +const _bucketDescriptions = { + 'prio0': {'priority': 0}, + 'prio1': {'priority': 1}, + 'prio2': {'priority': 2}, + 'prio3': {'priority': 3}, +}; diff --git a/dart/test/sync_local_performance_test.dart b/dart/test/sync_local_performance_test.dart new file mode 100644 index 0000000..9441138 --- /dev/null +++ b/dart/test/sync_local_performance_test.dart @@ -0,0 +1,327 @@ +import 'dart:convert'; + +import 'package:sqlite3/common.dart'; +import 'package:sqlite3/sqlite3.dart'; +import 'package:test/test.dart'; + +import 'utils/native_test_utils.dart'; +import 'utils/tracking_vfs.dart'; +import './schema_test.dart' show schema; + +// These tests measure how many filesystem reads and writes are performed during sync_local. +// The real-world performance of filesystem operations depends a lot on the specific system. +// For example, on native desktop systems, the performance of temporary filesystem storage could +// be close to memory performance. However, on web and mobile, (temporary) filesystem operations +// could drastically slow down performance. So rather than only testing the real time for these +// queries, we count the number of filesystem operations. +void testFilesystemOperations( + {bool unique = true, + int count = 200000, + int alreadyApplied = 10000, + int buckets = 10, + bool rawQueries = false}) { + late TrackingFileSystem vfs; + late CommonDatabase db; + final skip = rawQueries == false ? 'For manual query testing only' : null; + + setUpAll(() { + loadExtension(); + }); + + setUp(() async { + // Needs a unique name per test file to avoid concurrency issues + vfs = new TrackingFileSystem( + parent: new InMemoryFileSystem(), name: 'perf-test-vfs'); + sqlite3.registerVirtualFileSystem(vfs, makeDefault: false); + db = openTestDatabase(vfs: vfs, fileName: 'test.db'); + }); + + tearDown(() { + db.dispose(); + sqlite3.unregisterVirtualFileSystem(vfs); + }); + + setUp(() { + // Optional: set a custom cache size - it affects the number of filesystem operations. + // db.execute('PRAGMA cache_size=-50000'); + db.execute('SELECT powersync_replace_schema(?)', [json.encode(schema)]); + // Generate dummy data + // We can replace this with actual simulated download operations later + db.execute(''' +BEGIN TRANSACTION; + +WITH RECURSIVE generate_rows(n) AS ( + SELECT 1 + UNION ALL + SELECT n + 1 FROM generate_rows WHERE n < $count +) +INSERT INTO ps_oplog (bucket, op_id, row_type, row_id, key, data, hash) +SELECT + (n % $buckets), -- Generate n different buckets + n, + 'assets', + ${unique ? 'uuid()' : "'duplicated_id'"}, + uuid(), + '{"description": "' || n || '", "make": "test", "model": "this is just filler data. this is just filler data. this is just filler data. this is just filler data. this is just filler data. this is just filler data. this is just filler data. 
"}', + (n * 17) % 1000000000 -- Some pseudo-random hash + +FROM generate_rows; + +WITH RECURSIVE generate_bucket_rows(n) AS ( + SELECT 1 + UNION ALL + SELECT n + 1 FROM generate_bucket_rows WHERE n < $buckets +) +INSERT INTO ps_buckets (id, name, last_applied_op) +SELECT + (n % $buckets), + 'bucket' || n, + $alreadyApplied -- simulate a percentage of operations previously applied + +FROM generate_bucket_rows; + +COMMIT; +'''); + // Enable this to see stats for initial data generation + // print('init stats: ${vfs.stats()}'); + + vfs.clearStats(); + }); + + test('sync_local (full)', () { + var timer = Stopwatch()..start(); + db.select('insert into powersync_operations(op, data) values(?, ?)', + ['sync_local', '']); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + + // These are fairly generous limits, to catch significant regressions only. + expect(vfs.tempWrites, lessThan(count / 50)); + expect(timer.elapsed, + lessThan(Duration(milliseconds: 100 + (count / 50).round()))); + }); + + test('sync_local (partial)', () { + var timer = Stopwatch()..start(); + db.select('insert into powersync_operations(op, data) values(?, ?)', [ + 'sync_local', + jsonEncode({ + 'buckets': ['bucket0', 'bucket3', 'bucket4', 'bucket5', 'bucket6'], + 'priority': 2 + }) + ]); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + expect(vfs.tempWrites, lessThan(count / 50)); + expect(timer.elapsed, + lessThan(Duration(milliseconds: 100 + (count / 50).round()))); + }); + + // The tests below are for comparing different queries, not run as part of the + // standard test suite. + + test('sync_local new query', () { + // This is the query we're using now. + // This query only uses a single TEMP B-TREE for the GROUP BY operation, + // leading to fairly efficient execution. + + // QUERY PLAN + // |--CO-ROUTINE updated_rows + // | `--COMPOUND QUERY + // | |--LEFT-MOST SUBQUERY + // | | |--SCAN buckets + // | | `--SEARCH b USING INDEX ps_oplog_opid (bucket=? AND op_id>?) + // | `--UNION ALL + // | `--SCAN ps_updated_rows + // |--SCAN b + // |--USE TEMP B-TREE FOR GROUP BY + // `--CORRELATED SCALAR SUBQUERY 3 + // `--SEARCH r USING INDEX ps_oplog_row (row_type=? AND row_id=?) + // + // For details on the max(r.op_id) clause, see: + // https://sqlite.org/lang_select.html#bare_columns_in_an_aggregate_query + // > If there is exactly one min() or max() aggregate in the query, then all bare columns in the result + // > set take values from an input row which also contains the minimum or maximum. + + var timer = Stopwatch()..start(); + final q = ''' +-- 1. Filter oplog by the ops added but not applied yet (oplog b). +-- We do not do any DISTINCT operation here, since that introduces a temp b-tree. +-- We filter out duplicates using the GROUP BY below. +WITH updated_rows AS ( + SELECT b.row_type, b.row_id FROM ps_buckets AS buckets + CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id + AND (b.op_id > buckets.last_applied_op) + UNION ALL SELECT row_type, row_id FROM ps_updated_rows +) + +-- 2. Find *all* current ops over different buckets for those objects (oplog r). +SELECT + b.row_type, + b.row_id, + ( + -- 3. For each unique row, select the data from the latest oplog entry. + -- The max(r.op_id) clause is used to select the latest oplog entry. + -- The iif is to avoid the max(r.op_id) column ending up in the results. 
+ SELECT iif(max(r.op_id), r.data, null) + FROM ps_oplog r + WHERE r.row_type = b.row_type + AND r.row_id = b.row_id + + ) as data + FROM updated_rows b + -- Group for (2) + GROUP BY b.row_type, b.row_id; +'''; + db.select(q); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + }, skip: skip); + + test('old query', () { + // This query used a TEMP B-TREE for the first part of finding unique updated rows, + // then another TEMP B-TREE for the second GROUP BY. This redundant B-TREE causes + // a lot of temporary storage overhead. + + // QUERY PLAN + // |--CO-ROUTINE updated_rows + // | `--COMPOUND QUERY + // | |--LEFT-MOST SUBQUERY + // | | |--SCAN buckets + // | | `--SEARCH b USING INDEX ps_oplog_opid (bucket=? AND op_id>?) + // | `--UNION USING TEMP B-TREE + // | `--SCAN ps_updated_rows + // |--SCAN b + // |--SEARCH r USING INDEX ps_oplog_row (row_type=? AND row_id=?) LEFT-JOIN + // `--USE TEMP B-TREE FOR GROUP BY + + var timer = Stopwatch()..start(); + final q = ''' +WITH updated_rows AS ( + SELECT DISTINCT b.row_type, b.row_id FROM ps_buckets AS buckets + CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id + AND (b.op_id > buckets.last_applied_op) + UNION SELECT row_type, row_id FROM ps_updated_rows +) +SELECT b.row_type as type, + b.row_id as id, + r.data as data, + count(r.bucket) as buckets, + max(r.op_id) as op_id +FROM updated_rows b + LEFT OUTER JOIN ps_oplog AS r + ON r.row_type = b.row_type + AND r.row_id = b.row_id +GROUP BY b.row_type, b.row_id; +'''; + db.select(q); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + }, skip: skip); + + test('group_by query', () { + // This is similar to the new query, but uses a GROUP BY .. LIMIT 1 clause instead of the max(op_id) hack. + // It is similar in the number of filesystem operations, but slightly slower in real time. + + // QUERY PLAN + // |--CO-ROUTINE updated_rows + // | `--COMPOUND QUERY + // | |--LEFT-MOST SUBQUERY + // | | |--SCAN buckets + // | | `--SEARCH b USING INDEX ps_oplog_opid (bucket=? AND op_id>?) + // | `--UNION ALL + // | `--SCAN ps_updated_rows + // |--SCAN b + // |--USE TEMP B-TREE FOR GROUP BY + // `--CORRELATED SCALAR SUBQUERY 3 + // |--SEARCH r USING INDEX ps_oplog_row (row_type=? AND row_id=?) + // `--USE TEMP B-TREE FOR ORDER BY + + var timer = Stopwatch()..start(); + final q = ''' +WITH updated_rows AS ( + SELECT b.row_type, b.row_id FROM ps_buckets AS buckets + CROSS JOIN ps_oplog AS b ON b.bucket = buckets.id + AND (b.op_id > buckets.last_applied_op) + UNION ALL SELECT row_type, row_id FROM ps_updated_rows +) + +SELECT + b.row_type, + b.row_id, + ( + SELECT r.data FROM ps_oplog r + WHERE r.row_type = b.row_type + AND r.row_id = b.row_id + ORDER BY r.op_id DESC + LIMIT 1 + + ) as data + FROM updated_rows b + GROUP BY b.row_type, b.row_id; +'''; + db.select(q); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + }, skip: skip); + + test('full scan query', () { + // This is a nice alternative for initial sync or resyncing large amounts of data. + // This is very efficient for reading all data, but not for incremental updates. + + // QUERY PLAN + // |--SCAN r USING INDEX ps_oplog_row + // |--CORRELATED SCALAR SUBQUERY 1 + // | `--SEARCH ps_buckets USING INTEGER PRIMARY KEY (rowid=?) + // `--CORRELATED SCALAR SUBQUERY 1 + // `--SEARCH ps_buckets USING INTEGER PRIMARY KEY (rowid=?) 
+ + var timer = Stopwatch()..start(); + final q = ''' +SELECT r.row_type as type, + r.row_id as id, + r.data as data, + max(r.op_id) as op_id, + sum((select 1 from ps_buckets where ps_buckets.id = r.bucket and r.op_id > ps_buckets.last_applied_op)) as buckets + +FROM ps_oplog r +GROUP BY r.row_type, r.row_id +HAVING buckets > 0; +'''; + db.select(q); + print('${timer.elapsed.inMilliseconds}ms ${vfs.stats()}'); + }, skip: skip); +} + +main() { + group('test filesystem operations with unique ids', () { + testFilesystemOperations( + unique: true, + count: 500000, + alreadyApplied: 10000, + buckets: 10, + rawQueries: false); + }); + group('test filesytem operations with duplicate ids', () { + // If this takes more than a couple of milliseconds to complete, there is a performance bug + testFilesystemOperations( + unique: false, + count: 500000, + alreadyApplied: 1000, + buckets: 10, + rawQueries: false); + }); + + group('test filesystem operations with a small number of changes', () { + testFilesystemOperations( + unique: true, + count: 100000, + alreadyApplied: 95000, + buckets: 10, + rawQueries: false); + }); + + group('test filesystem operations with a large number of buckets', () { + testFilesystemOperations( + unique: true, + count: 100000, + alreadyApplied: 10000, + buckets: 1000, + rawQueries: false); + }); +} diff --git a/dart/test/sync_test.dart b/dart/test/sync_test.dart index 39b2bb7..fe78666 100644 --- a/dart/test/sync_test.dart +++ b/dart/test/sync_test.dart @@ -1,16 +1,30 @@ import 'dart:convert'; +import 'dart:io'; +import 'dart:typed_data'; +import 'package:bson/bson.dart'; import 'package:fake_async/fake_async.dart'; import 'package:file/local.dart'; +import 'package:meta/meta.dart'; import 'package:sqlite3/common.dart'; import 'package:sqlite3/sqlite3.dart'; import 'package:sqlite3_test/sqlite3_test.dart'; import 'package:test/test.dart'; +import 'package:path/path.dart'; import 'utils/native_test_utils.dart'; +@isTest +void syncTest(String description, void Function(FakeAsync controller) body) { + return test(description, () { + // Give each test the same starting time to make goldens easier to compare. + fakeAsync(body, initialTime: DateTime.utc(2025, 3, 1, 10)); + }); +} + void main() { - final vfs = TestSqliteFileSystem(fs: const LocalFileSystem()); + final vfs = + TestSqliteFileSystem(fs: const LocalFileSystem(), name: 'vfs-sync-test'); setUpAll(() { loadExtension(); @@ -18,307 +32,652 @@ void main() { }); tearDownAll(() => sqlite3.unregisterVirtualFileSystem(vfs)); - group('sync tests', () { - late CommonDatabase db; + group('text lines', () { + _syncTests(vfs: vfs, isBson: false); + }); - setUp(() async { - db = openTestDatabase(vfs) - ..select('select powersync_init();') - ..select('select powersync_replace_schema(?)', [json.encode(_schema)]); - }); + group('bson lines', () { + _syncTests(vfs: vfs, isBson: true); + }); +} - tearDown(() { - db.dispose(); - }); +void _syncTests({ + required VirtualFileSystem vfs, + required bool isBson, +}) { + late CommonDatabase db; + late SyncLinesGoldenTest matcher; + + List invokeControlRaw(String operation, Object? data) { + db.execute('begin'); + final [row] = + db.select('SELECT powersync_control(?, ?)', [operation, data]); + db.execute('commit'); + return jsonDecode(row.columnAt(0)); + } + + List invokeControl(String operation, Object? 
data) { + if (matcher.enabled) { + // Trace through golden matcher + return matcher.invoke(operation, data); + } else { + return invokeControlRaw(operation, data); + } + } - void pushSyncData( - String bucket, - String opId, - String rowId, - Object op, - Object? data, { - Object? descriptions = _bucketDescriptions, - }) { - final encoded = json.encode({ - 'buckets': [ + setUp(() async { + db = openTestDatabase(vfs: vfs) + ..select('select powersync_init();') + ..select('select powersync_replace_schema(?)', [json.encode(_schema)]) + ..execute('update ps_kv set value = ?2 where key = ?1', + ['client_id', 'test-test-test-test']); + + matcher = SyncLinesGoldenTest(isBson, invokeControlRaw); + }); + + tearDown(() { + matcher.finish(); + db.dispose(); + }); + + List syncLine(Object? line) { + if (isBson) { + final serialized = BsonCodec.serialize(line).byteList; + // print(serialized.asRustByteString); + return invokeControl('line_binary', serialized); + } else { + return invokeControl('line_text', jsonEncode(line)); + } + } + + List pushSyncData( + String bucket, String opId, String rowId, Object op, Object? data, + {int checksum = 0}) { + return syncLine({ + 'data': { + 'bucket': bucket, + 'has_more': false, + 'after': null, + 'next_after': null, + 'data': [ { - 'bucket': bucket, - 'data': [ - { - 'op_id': opId, - 'op': op, - 'object_type': 'items', - 'object_id': rowId, - 'checksum': 0, - 'data': data, - } - ], + 'op_id': opId, + 'op': op, + 'object_type': 'items', + 'object_id': rowId, + 'checksum': checksum, + 'data': json.encode(data), } ], - if (descriptions != null) 'descriptions': descriptions, - }); + }, + }); + } + + List pushCheckpoint( + {int lastOpId = 1, List buckets = const []}) { + return syncLine({ + 'checkpoint': { + 'last_op_id': '$lastOpId', + 'write_checkpoint': null, + 'buckets': buckets, + }, + }); + } - db.execute('insert into powersync_operations (op, data) VALUES (?, ?);', - ['save', encoded]); - } + List pushCheckpointComplete({int? priority, String lastOpId = '1'}) { + return syncLine({ + priority == null ? 'checkpoint_complete' : 'partial_checkpoint_complete': + { + 'last_op_id': lastOpId, + if (priority != null) 'priority': priority, + }, + }); + } - bool pushCheckpointComplete( - String lastOpId, String? writeCheckpoint, List checksums, - {int? priority}) { - final [row] = db.select('select powersync_validate_checkpoint(?) as r;', [ + ResultSet fetchRows() { + return db.select('select * from items'); + } + + group('goldens', () { + syncTest('starting stream', (_) { + matcher.load('starting_stream'); + invokeControl( + 'start', json.encode({ - 'last_op_id': lastOpId, - 'write_checkpoint': writeCheckpoint, + 'parameters': {'foo': 'bar'} + }), + ); + }); + + syncTest('simple sync iteration', (_) { + matcher.load('simple_iteration'); + invokeControl('start', null); + + syncLine({ + 'checkpoint': { + 'last_op_id': '1', + 'write_checkpoint': null, 'buckets': [ - for (final cs in checksums.cast>()) - if (priority == null || cs['priority'] <= priority) cs + { + 'bucket': 'a', + 'checksum': 0, + 'priority': 3, + 'count': 1, + } ], - 'priority': priority, - }) - ]); + }, + }); + syncLine({'token_expires_in': 60}); + pushSyncData('a', '1', '1', 'PUT', {'col': 'hi'}); - final decoded = json.decode(row['r']); - if (decoded['valid'] != true) { - fail(row['r']); - } + syncLine({ + 'checkpoint_complete': {'last_op_id': '1'}, + }); - db.execute( - 'UPDATE ps_buckets SET last_op = ? 
WHERE name IN (SELECT json_each.value FROM json_each(?))', - [ - lastOpId, - json.encode(checksums.map((e) => (e as Map)['bucket']).toList()) - ], - ); + syncLine({'token_expires_in': 10}); + }); + }); - db.execute('INSERT INTO powersync_operations(op, data) VALUES (?, ?)', [ - 'sync_local', - priority != null - ? jsonEncode({ - 'priority': priority, - 'buckets': [ - for (final cs in checksums.cast>()) - if (cs['priority'] <= priority) cs['bucket'] - ], - }) - : null, - ]); - return db.lastInsertRowId == 1; - } + test('does not publish until reaching checkpoint', () { + invokeControl('start', null); + pushCheckpoint(buckets: priorityBuckets); + expect(fetchRows(), isEmpty); + db.execute("insert into items (id, col) values ('local', 'data');"); - ResultSet fetchRows() { - return db.select('select * from items'); - } + pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); - test('does not publish until reaching checkpoint', () { - expect(fetchRows(), isEmpty); - pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); - expect(fetchRows(), isEmpty); + pushCheckpointComplete(); + expect(fetchRows(), [ + {'id': 'local', 'col': 'data'} + ]); + }); - expect( - pushCheckpointComplete( - '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), - isTrue); - expect(fetchRows(), [ - {'id': 'row-0', 'col': 'hi'} - ]); - }); + test('publishes with local data for prio=0 buckets', () { + invokeControl('start', null); + pushCheckpoint(buckets: priorityBuckets); + expect(fetchRows(), isEmpty); + db.execute("insert into items (id, col) values ('local', 'data');"); - test('does not publish with pending local data', () { - expect(fetchRows(), isEmpty); - db.execute("insert into items (id, col) values ('local', 'data');"); - expect(fetchRows(), isNotEmpty); + pushSyncData('prio0', '1', 'row-0', 'PUT', {'col': 'hi'}); - pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); - expect( - pushCheckpointComplete( - '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), - isFalse); - expect(fetchRows(), [ - {'id': 'local', 'col': 'data'} - ]); - }); + pushCheckpointComplete(priority: 0); + expect(fetchRows(), [ + {'id': 'local', 'col': 'data'}, + {'id': 'row-0', 'col': 'hi'}, + ]); + }); - test('publishes with local data for prio=0 buckets', () { - expect(fetchRows(), isEmpty); - db.execute("insert into items (id, col) values ('local', 'data');"); - expect(fetchRows(), isNotEmpty); + test('does not publish with pending local data', () { + invokeControl('start', null); + pushCheckpoint(buckets: priorityBuckets); + db.execute("insert into items (id, col) values ('local', 'data');"); + expect(fetchRows(), isNotEmpty); - pushSyncData('prio0', '1', 'row-0', 'PUT', {'col': 'hi'}); - expect( - pushCheckpointComplete( - '1', - null, - [_bucketChecksum('prio0', 0, checksum: 0)], - priority: 0, - ), - isTrue, + pushCheckpoint(buckets: priorityBuckets); + pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); + pushCheckpointComplete(); + + expect(fetchRows(), [ + {'id': 'local', 'col': 'data'} + ]); + }); + + test('can publish partial checkpoints under different priorities', () { + invokeControl('start', null); + pushCheckpoint(buckets: priorityBuckets); + for (var i = 0; i < 4; i++) { + pushSyncData('prio$i', '1', 'row-$i', 'PUT', {'col': '$i'}); + } + + expect(fetchRows(), isEmpty); + + // Simulate a partial checkpoint complete for each of the buckets. 
+ for (var i = 0; i < 4; i++) { + pushCheckpointComplete( + priority: i, ); + expect(fetchRows(), [ - {'id': 'local', 'col': 'data'}, - {'id': 'row-0', 'col': 'hi'}, + for (var j = 0; j <= i; j++) {'id': 'row-$j', 'col': '$j'}, ]); - }); - test('can publish partial checkpoints under different priorities', () { - for (var i = 0; i < 4; i++) { - pushSyncData('prio$i', '1', 'row-$i', 'PUT', {'col': '$i'}); - } - expect(fetchRows(), isEmpty); + expect(db.select('select 1 from ps_sync_state where priority = ?', [i]), + isNotEmpty); + // A sync at this priority includes all higher priorities too, so they + // should be cleared. + expect(db.select('select 1 from ps_sync_state where priority < ?', [i]), + isEmpty); + } + }); + + syncTest('can sync multiple times', (controller) { + invokeControl('start', null); - // Simulate a partial checkpoint complete for each of the buckets. - for (var i = 0; i < 4; i++) { + for (var i = 0; i < 10; i++) { + pushCheckpoint(buckets: priorityBuckets); + + for (var prio in const [1, 2, 3, null]) { + pushCheckpointComplete(priority: prio); + + // Make sure there's only a single row in last_synced_at expect( - pushCheckpointComplete( - '1', - null, + db.select( + "SELECT datetime(last_synced_at) AS last_synced_at FROM ps_sync_state WHERE priority = ?", + [prio ?? 2147483647]), + [ + {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} + ], + ); + + if (prio == null) { + expect( + db.select( + "SELECT datetime(powersync_last_synced_at()) AS last_synced_at"), [ - for (var j = 0; j <= 4; j++) - _bucketChecksum( - 'prio$j', - j, - // Give buckets outside of the current priority a wrong - // checksum. They should not be validated yet. - checksum: j <= i ? 0 : 1234, - ), + {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} ], - priority: i, - ), - isTrue, - ); + ); + } + } - expect(fetchRows(), [ - for (var j = 0; j <= i; j++) {'id': 'row-$j', 'col': '$j'}, - ]); + controller.elapse(const Duration(hours: 1)); + } + }); - expect(db.select('select 1 from ps_sync_state where priority = ?', [i]), - isNotEmpty); - // A sync at this priority includes all higher priorities too, so they - // should be cleared. 
- expect(db.select('select 1 from ps_sync_state where priority < ?', [i]), - isEmpty); - } + test('clearing database clears sync status', () { + invokeControl('start', null); + pushCheckpoint(buckets: priorityBuckets); + pushCheckpointComplete(); + + expect(db.select('SELECT powersync_last_synced_at() AS r').single, + {'r': isNotNull}); + expect(db.select('SELECT priority FROM ps_sync_state').single, + {'priority': 2147483647}); + + db.execute('SELECT powersync_clear(0)'); + expect(db.select('SELECT powersync_last_synced_at() AS r').single, + {'r': isNull}); + expect(db.select('SELECT * FROM ps_sync_state'), hasLength(0)); + }); + + test('persists download progress', () { + const bucket = 'bkt'; + void expectProgress(int atLast, int sinceLast) { + final [row] = db.select( + 'SELECT count_at_last, count_since_last FROM ps_buckets WHERE name = ?', + [bucket], + ); + final [actualAtLast, actualSinceLast] = row.values; + + expect(actualAtLast, atLast, reason: 'count_at_last mismatch'); + expect(actualSinceLast, sinceLast, reason: 'count_since_last mismatch'); + } + + invokeControl('start', null); + pushCheckpoint(buckets: [bucketDescription(bucket, count: 2)]); + pushCheckpointComplete(); + + pushSyncData(bucket, '1', 'row-0', 'PUT', {'col': 'hi'}); + expectProgress(0, 1); + + pushSyncData(bucket, '1', 'row-1', 'PUT', {'col': 'hi again'}); + expectProgress(0, 2); + + pushCheckpointComplete(lastOpId: '2'); + expectProgress(2, 0); + }); + + test('deletes old buckets', () { + for (final name in ['one', 'two', 'three', r'$local']) { + db.execute('INSERT INTO ps_buckets (name) VALUES (?)', [name]); + } + + expect( + invokeControl('start', null), + contains( + containsPair( + 'EstablishSyncStream', + containsPair('request', containsPair('buckets', hasLength(3))), + ), + ), + ); + + syncLine({ + 'checkpoint': { + 'last_op_id': '1', + 'write_checkpoint': null, + 'buckets': [ + { + 'bucket': 'one', + 'checksum': 0, + 'priority': 3, + 'count': 1, + } + ], + }, }); - test('can sync multiple times', () { - fakeAsync((controller) { - for (var i = 0; i < 10; i++) { - for (var prio in const [1, 2, 3, null]) { - pushCheckpointComplete('1', null, [], priority: prio); - - // Make sure there's only a single row in last_synced_at - expect( - db.select( - "SELECT datetime(last_synced_at, 'localtime') AS last_synced_at FROM ps_sync_state WHERE priority = ?", - [prio ?? 2147483647]), - [ - {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} - ], - ); - - if (prio == null) { - expect( - db.select( - "SELECT datetime(powersync_last_synced_at(), 'localtime') AS last_synced_at"), - [ - {'last_synced_at': '2025-03-01 ${10 + i}:00:00'} - ], - ); + // Should delete the old buckets two and three + expect(db.select('select name from ps_buckets order by id'), [ + {'name': 'one'}, + {'name': r'$local'} + ]); + }); + + if (isBson) { + test('can parse checksums from JS numbers', () { + invokeControl('start', null); + pushCheckpoint(buckets: [bucketDescription('global[]')]); + + syncLine({ + 'data': { + 'bucket': 'a', + 'has_more': false, + 'after': null, + 'next_after': null, + 'data': [ + { + 'op_id': '1', + 'op': 'PUT', + 'object_type': 'items', + 'object_id': 'id', + 'checksum': 3573495687.0, + 'data': '{}', } + ], + }, + }); + }); + } + + group('progress', () { + Map? 
progress = null; + var lastOpId = 0; + + setUp(() { + lastOpId = 0; + return progress = null; + }); + + (int, int) totalProgress() { + return progress!.values.downloadAndTargetCount(); + } + + (int, int) priorityProgress(int priority) { + return progress!.values + .where((e) => e.priority <= priority) + .downloadAndTargetCount(); + } + + void applyInstructions(List instructions) { + for (final instruction in instructions.cast()) { + if (instruction['UpdateSyncStatus'] case final updateStatus?) { + final downloading = updateStatus['status']['downloading']; + if (downloading == null) { + progress = null; + } else { + progress = { + for (final MapEntry(:key, :value) + in downloading['buckets'].entries) + key: ( + atLast: value['at_last'] as int, + sinceLast: value['since_last'] as int, + targetCount: value['target_count'] as int, + priority: value['priority'] as int, + ), + }; } + } + } + } + + void pushSyncData(String bucket, int amount) { + final instructions = syncLine({ + 'data': { + 'bucket': bucket, + 'has_more': false, + 'after': null, + 'next_after': null, + 'data': [ + for (var i = 0; i < amount; i++) + { + 'op_id': (++lastOpId).toString(), + 'op': 'PUT', + 'object_type': 'items', + 'object_id': '$lastOpId', + 'checksum': 0, + 'data': '{}', + } + ], + }, + }); + + applyInstructions(instructions); + } - controller.elapse(const Duration(hours: 1)); + void addCheckpointComplete({int? priority}) { + applyInstructions( + pushCheckpointComplete(priority: priority, lastOpId: '$lastOpId')); + } + + test('without priorities', () { + applyInstructions(invokeControl('start', null)); + expect(progress, isNull); + + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 10)], lastOpId: 10)); + expect(totalProgress(), (0, 10)); + + pushSyncData('a', 10); + expect(totalProgress(), (10, 10)); + + addCheckpointComplete(); + expect(progress, isNull); + + // Emit new data, progress should be 0/2 instead of 10/12 + applyInstructions(syncLine({ + 'checkpoint_diff': { + 'last_op_id': '12', + 'updated_buckets': [ + { + 'bucket': 'a', + 'priority': 3, + 'checksum': 0, + 'count': 12, + 'last_op_id': null + }, + ], + 'removed_buckets': [], + 'write_checkpoint': null, } - }, initialTime: DateTime(2025, 3, 1, 10)); + })); + expect(totalProgress(), (0, 2)); + + pushSyncData('a', 2); + expect(totalProgress(), (2, 2)); + + addCheckpointComplete(); + expect(progress, isNull); }); - test('clearing database clears sync status', () { - pushSyncData('prio1', '1', 'row-0', 'PUT', {'col': 'hi'}); + test('interrupted sync', () { + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 10)], lastOpId: 10)); + expect(totalProgress(), (0, 10)); - expect( - pushCheckpointComplete( - '1', null, [_bucketChecksum('prio1', 1, checksum: 0)]), - isTrue); - expect(db.select('SELECT powersync_last_synced_at() AS r').single, - {'r': isNotNull}); - expect(db.select('SELECT priority FROM ps_sync_state').single, - {'priority': 2147483647}); - - db.execute('SELECT powersync_clear(0)'); - expect(db.select('SELECT powersync_last_synced_at() AS r').single, - {'r': isNull}); - expect(db.select('SELECT * FROM ps_sync_state'), hasLength(0)); + pushSyncData('a', 5); + expect(totalProgress(), (5, 10)); + + // Emulate stream closing + applyInstructions(invokeControl('stop', null)); + expect(progress, isNull); + + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 10)], lastOpId: 
10)); + expect(totalProgress(), (5, 10)); + + pushSyncData('a', 5); + expect(totalProgress(), (10, 10)); + addCheckpointComplete(); + expect(progress, isNull); }); - test('tracks download progress', () { - const bucket = 'bkt'; - void expectProgress(int atLast, int sinceLast) { - final [row] = db.select( - 'SELECT count_at_last, count_since_last FROM ps_buckets WHERE name = ?', - [bucket], - ); - final [actualAtLast, actualSinceLast] = row.values; + test('interrupted sync with new checkpoint', () { + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 10)], lastOpId: 10)); + expect(totalProgress(), (0, 10)); + + pushSyncData('a', 5); + expect(totalProgress(), (5, 10)); + + // Emulate stream closing + applyInstructions(invokeControl('stop', null)); + expect(progress, isNull); + + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 12)], lastOpId: 12)); + expect(totalProgress(), (5, 12)); + + pushSyncData('a', 7); + expect(totalProgress(), (12, 12)); + addCheckpointComplete(); + expect(progress, isNull); + }); - expect(actualAtLast, atLast, reason: 'count_at_last mismatch'); - expect(actualSinceLast, sinceLast, reason: 'count_since_last mismatch'); + test('interrupt and defrag', () { + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 10)], lastOpId: 10)); + expect(totalProgress(), (0, 10)); + + pushSyncData('a', 5); + expect(totalProgress(), (5, 10)); + + // Emulate stream closing + applyInstructions(invokeControl('stop', null)); + expect(progress, isNull); + + applyInstructions(invokeControl('start', null)); + // A defrag in the meantime shrank the bucket. + applyInstructions(pushCheckpoint( + buckets: [bucketDescription('a', count: 4)], lastOpId: 14)); + // So we shouldn't report 5/4. + expect(totalProgress(), (0, 4)); + + // This should also reset the persisted progress counters. 
+ final [bucket] = db.select('SELECT * FROM ps_buckets'); + expect(bucket, containsPair('count_since_last', 0)); + expect(bucket, containsPair('count_at_last', 0)); + }); + + test('different priorities', () { + void expectProgress((int, int) prio0, (int, int) prio2) { + expect(priorityProgress(0), prio0); + expect(priorityProgress(1), prio0); + expect(priorityProgress(2), prio2); + expect(totalProgress(), prio2); } - pushSyncData(bucket, '1', 'row-0', 'PUT', {'col': 'hi'}); - expectProgress(0, 1); + applyInstructions(invokeControl('start', null)); + applyInstructions(pushCheckpoint(buckets: [ + bucketDescription('a', count: 5, priority: 0), + bucketDescription('b', count: 5, priority: 2), + ], lastOpId: 10)); + expectProgress((0, 5), (0, 10)); + + pushSyncData('a', 5); + expectProgress((5, 5), (5, 10)); + + pushSyncData('b', 2); + expectProgress((5, 5), (7, 10)); + + // Before syncing b fully, send a new checkpoint + applyInstructions(pushCheckpoint(buckets: [ + bucketDescription('a', count: 8, priority: 0), + bucketDescription('b', count: 6, priority: 2), + ], lastOpId: 14)); + expectProgress((5, 8), (7, 14)); + + pushSyncData('a', 3); + expectProgress((8, 8), (10, 14)); + pushSyncData('b', 4); + expectProgress((8, 8), (14, 14)); + + addCheckpointComplete(); + expect(progress, isNull); + }); + }); - pushSyncData(bucket, '2', 'row-1', 'PUT', {'col': 'hi'}); - expectProgress(0, 2); + group('errors', () { + syncTest('diff without prior checkpoint', (_) { + invokeControl('start', null); expect( - pushCheckpointComplete( - '2', - null, - [_bucketChecksum(bucket, 1, checksum: 0)], - priority: 1, + () => syncLine({ + 'checkpoint_diff': { + 'last_op_id': '1', + 'write_checkpoint': null, + 'updated_buckets': [], + 'removed_buckets': [], + }, + }), + throwsA( + isA().having( + (e) => e.message, + 'message', + contains('checkpoint_diff without previous checkpoint'), + ), ), - isTrue, ); + }); - // Running partial or complete checkpoints should not reset stats, client - // SDKs are responsible for that. - expectProgress(0, 2); - expect(db.select('SELECT * FROM items'), isNotEmpty); + syncTest('checksum mismatch', (_) { + invokeControl('start', null); + + syncLine({ + 'checkpoint': { + 'last_op_id': '1', + 'write_checkpoint': null, + 'buckets': [ + { + 'bucket': 'a', + 'checksum': 1234, + 'priority': 3, + 'count': 1, + } + ], + }, + }); + pushSyncData('a', '1', '1', 'PUT', {'col': 'hi'}, checksum: 4321); + + expect(db.select('SELECT * FROM ps_buckets'), hasLength(1)); expect( - pushCheckpointComplete( - '2', - null, - [_bucketChecksum(bucket, 1, checksum: 0)], - ), - isTrue, + syncLine({ + 'checkpoint_complete': {'last_op_id': '1'}, + }), + [ + { + 'LogLine': { + 'severity': 'WARNING', + 'line': contains( + "Checksums didn't match, failed for: a (expected 0x000004d2, got 0x000010e1 = 0x000010e1 (op) + 0x00000000 (add))") + } + }, + {'CloseSyncStream': {}}, + ], ); - expectProgress(0, 2); - db.execute(''' -UPDATE ps_buckets SET count_since_last = 0, count_at_last = ?1->name - WHERE ?1->name IS NOT NULL -''', [ - json.encode({bucket: 2}), - ]); - expectProgress(2, 0); - - // Run another iteration of this - pushSyncData(bucket, '3', 'row-3', 'PUT', {'col': 'hi'}); - expectProgress(2, 1); - db.execute(''' -UPDATE ps_buckets SET count_since_last = 0, count_at_last = ?1->name - WHERE ?1->name IS NOT NULL -''', [ - json.encode({bucket: 3}), - ]); - expectProgress(3, 0); + // Should delete bucket with checksum mismatch + expect(db.select('SELECT * FROM ps_buckets'), isEmpty); }); }); } -Object? 
_bucketChecksum(String bucket, int prio, {int checksum = 0}) { - return {'bucket': bucket, 'priority': prio, 'checksum': checksum}; -} - const _schema = { 'tables': [ { @@ -330,9 +689,167 @@ const _schema = { ] }; -const _bucketDescriptions = { - 'prio0': {'priority': 0}, - 'prio1': {'priority': 1}, - 'prio2': {'priority': 2}, - 'prio3': {'priority': 3}, -}; +Object bucketDescription(String name, + {int checksum = 0, int priority = 3, int count = 1}) { + return { + 'bucket': name, + 'checksum': checksum, + 'priority': priority, + 'count': count, + }; +} + +final priorityBuckets = [ + for (var i = 0; i < 4; i++) bucketDescription('prio$i', priority: i) +]; + +typedef BucketProgress = ({ + int priority, + int atLast, + int sinceLast, + int targetCount +}); + +extension on Iterable { + (int, int) downloadAndTargetCount() { + return fold((0, 0), (counters, entry) { + final (downloaded, total) = counters; + + return ( + downloaded + entry.sinceLast, + total + entry.targetCount - entry.atLast + ); + }); + } +} + +extension on Uint8List { + // ignore: unused_element + String get asRustByteString { + final buffer = StringBuffer('b"'); + + for (final byte in this) { + switch (byte) { + case >= 32 && < 127: + buffer.writeCharCode(byte); + default: + // Escape + buffer.write('\\x${byte.toRadixString(16).padLeft(2, '0')}'); + } + } + + buffer.write('"'); + return buffer.toString(); + } +} + +final class SyncLinesGoldenTest { + static bool _update = Platform.environment['UPDATE_GOLDENS'] == '1'; + + final List Function(String operation, Object? data) _invokeControl; + + String? name; + + final bool isBson; + final List expectedLines = []; + final List actualLines = []; + + String get path => join('test', 'goldens', '$name.json'); + + bool get enabled => name != null; + + SyncLinesGoldenTest(this.isBson, this._invokeControl); + + ExpectedSyncLine get _nextExpectation { + return expectedLines[actualLines.length]; + } + + void _checkMismatch(void Function() compare) { + try { + compare(); + } catch (e) { + print( + 'Golden test for sync lines failed, set UPDATE_GOLDENS=1 to update'); + rethrow; + } + } + + void load(String name) { + this.name = name; + final file = File(path); + try { + final loaded = json.decode(file.readAsStringSync()); + for (final entry in loaded) { + expectedLines.add(ExpectedSyncLine.fromJson(entry)); + } + } catch (e) { + if (!_update) { + rethrow; + } + } + } + + List invoke(String operation, Object? data) { + final matchData = switch (data) { + final String s => json.decode(s), + _ => data, + }; + + if (_update) { + final result = _invokeControl(operation, data); + actualLines.add(ExpectedSyncLine(operation, matchData, result)); + return result; + } else { + final expected = _nextExpectation; + if (!isBson) { + // We only want to compare the JSON inputs. We compare outputs + // regardless of the encoding mode. + _checkMismatch(() { + expect(operation, expected.operation); + expect(matchData, expected.data); + }); + } + + final result = _invokeControl(operation, data); + _checkMismatch(() { + expect(result, expected.output); + }); + + actualLines.add(ExpectedSyncLine(operation, matchData, result)); + return result; + } + } + + void finish() { + if (_update && enabled) { + if (!isBson) { + File(path).writeAsStringSync( + JsonEncoder.withIndent(' ').convert(actualLines)); + } + } else { + _checkMismatch( + () => expect(actualLines, hasLength(expectedLines.length))); + } + } +} + +final class ExpectedSyncLine { + final String operation; + final Object? 
data; + final List output; + + ExpectedSyncLine(this.operation, this.data, this.output); + + factory ExpectedSyncLine.fromJson(Map json) { + return ExpectedSyncLine( + json['operation'] as String, json['data'], json['output'] as List); + } + + Map toJson() { + return { + 'operation': operation, + 'data': data, + 'output': output, + }; + } +} diff --git a/dart/test/utils/native_test_utils.dart b/dart/test/utils/native_test_utils.dart index a6ec244..e65c753 100644 --- a/dart/test/utils/native_test_utils.dart +++ b/dart/test/utils/native_test_utils.dart @@ -26,13 +26,14 @@ void applyOpenOverride() { }); } -CommonDatabase openTestDatabase([VirtualFileSystem? vfs]) { +CommonDatabase openTestDatabase( + {VirtualFileSystem? vfs, String fileName = ':memory:'}) { applyOpenOverride(); if (!didLoadExtension) { loadExtension(); } - return sqlite3.open(':memory:', vfs: vfs?.name); + return sqlite3.open(fileName, vfs: vfs?.name); } void loadExtension() { diff --git a/dart/test/utils/tracking_vfs.dart b/dart/test/utils/tracking_vfs.dart new file mode 100644 index 0000000..86c2707 --- /dev/null +++ b/dart/test/utils/tracking_vfs.dart @@ -0,0 +1,118 @@ +import 'dart:typed_data'; + +import 'package:sqlite3/sqlite3.dart'; + +final class TrackingFileSystem extends BaseVirtualFileSystem { + BaseVirtualFileSystem parent; + int tempReads = 0; + int tempWrites = 0; + int dataReads = 0; + int dataWrites = 0; + + TrackingFileSystem({super.name = 'tracking', required this.parent}); + + @override + int xAccess(String path, int flags) { + return parent.xAccess(path, flags); + } + + @override + void xDelete(String path, int syncDir) { + parent.xDelete(path, syncDir); + } + + @override + String xFullPathName(String path) { + return parent.xFullPathName(path); + } + + @override + XOpenResult xOpen(Sqlite3Filename path, int flags) { + final result = parent.xOpen(path, flags); + return ( + outFlags: result.outFlags, + file: TrackingFile( + result.file, this, flags & SqlFlag.SQLITE_OPEN_DELETEONCLOSE != 0), + ); + } + + @override + void xSleep(Duration duration) {} + + String stats() { + return "Reads: $dataReads + $tempReads | Writes: $dataWrites + $tempWrites"; + } + + void clearStats() { + tempReads = 0; + tempWrites = 0; + dataReads = 0; + dataWrites = 0; + } +} + +class TrackingFile implements VirtualFileSystemFile { + final TrackingFileSystem vfs; + final VirtualFileSystemFile parentFile; + final bool deleteOnClose; + + TrackingFile(this.parentFile, this.vfs, this.deleteOnClose); + + @override + void xWrite(Uint8List buffer, int fileOffset) { + if (deleteOnClose) { + vfs.tempWrites++; + } else { + vfs.dataWrites++; + } + parentFile.xWrite(buffer, fileOffset); + } + + @override + void xRead(Uint8List buffer, int offset) { + if (deleteOnClose) { + vfs.tempReads++; + } else { + vfs.dataReads++; + } + parentFile.xRead(buffer, offset); + } + + @override + int xCheckReservedLock() { + return parentFile.xCheckReservedLock(); + } + + @override + void xClose() { + return parentFile.xClose(); + } + + @override + int xFileSize() { + return parentFile.xFileSize(); + } + + @override + void xLock(int mode) { + return parentFile.xLock(mode); + } + + @override + void xSync(int flags) { + return parentFile.xSync(flags); + } + + @override + void xTruncate(int size) { + return parentFile.xTruncate(size); + } + + @override + void xUnlock(int mode) { + return parentFile.xUnlock(mode); + } + + @override + int get xDeviceCharacteristics => parentFile.xDeviceCharacteristics; +} diff --git a/docs/sync.md b/docs/sync.md new file mode 100644 
index 0000000..b24684b --- /dev/null +++ b/docs/sync.md @@ -0,0 +1,82 @@ +## Sync interface + +The core extension implements the state machine and necessary SQL handling to decode and apply +sync lines sent from a PowerSync service instance. + +After registering the PowerSync extension, the sync client is available through the `powersync_control` +function, which takes two arguments: a command (text) and a payload (text, blob, or null). +The function should always be called in a transaction. + +The following commands are supported: + +1. `start`: Payload is a JSON-encoded object. This requests the client to start a sync iteration. + The payload can either be `null` or a JSON object with: + - An optional `parameters: Record` entry, specifying parameters to include in the request + to the sync service. +2. `stop`: No payload; requests the current sync iteration (if any) to be shut down. +3. `line_text`: Payload is a serialized JSON object received from the sync service. +4. `line_binary`: Payload is a BSON-encoded object received from the sync service. +5. `refreshed_token`: Notify the sync client that the JWT used to authenticate to the PowerSync service has + changed. + - The client will emit an instruction to stop the current stream; clients should restart by sending another `start` + command. +6. `completed_upload`: Notify the sync implementation that all local changes have been uploaded. + +`powersync_control` returns a JSON-encoded array of instructions for the client: + +```typescript +type Instruction = { LogLine: LogLine } + | { UpdateSyncStatus: UpdateSyncStatus } + | { EstablishSyncStream: EstablishSyncStream } + | { FetchCredentials: FetchCredentials } + // Close a connection previously started after EstablishSyncStream + | { CloseSyncStream: {} } + // For the Dart web client, flush the (otherwise non-durable) file system. + | { FlushFileSystem: {} } + // Notify clients that a checkpoint was completed. Clients can clear the + // download error state in response to this. + | { DidCompleteSync: {} } + +interface LogLine { + severity: 'DEBUG' | 'INFO' | 'WARNING', + line: String, +} + +// Instructs client SDKs to open a connection to the sync service. +interface EstablishSyncStream { + request: any // The JSON-encoded StreamingSyncRequest to send to the sync service +} + +// Instructs SDKs to update the downloading state of their SyncStatus. +interface UpdateSyncStatus { + connected: boolean, + connecting: boolean, + priority_status: [], + downloading: null | DownloadProgress, +} + +// Instructs SDKs to refresh credentials from the backend connector. +// They don't necessarily have to close the connection; a CloseSyncStream instruction +// will be sent when the token has already expired. +interface FetchCredentials { + // Set as an option in case fetching and prefetching should be handled differently.
+ did_expire: boolean +} + +interface SyncPriorityStatus { + priority: int, + last_synced_at: null | int, + has_synced: null | boolean, +} + +interface DownloadProgress { + buckets: Record +} + +interface BucketProgress { + priority: int, + at_last: int, + since_last: int, + target_count: int +} +``` diff --git a/rust-toolchain.toml b/rust-toolchain.toml index e418b22..5d54722 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-05-18" +channel = "nightly-2025-04-15" diff --git a/tool/build_wasm.sh b/tool/build_wasm.sh index da505e7..f40d8ed 100755 --- a/tool/build_wasm.sh +++ b/tool/build_wasm.sh @@ -1,5 +1,6 @@ #!/bin/bash set -e +emcc --version # Normal build # target/wasm32-unknown-emscripten/wasm/powersync.wasm @@ -31,13 +32,13 @@ cp "target/wasm32-unknown-emscripten/wasm_asyncify/powersync.wasm" "libpowersync # Static lib. # Works for both sync and asyncify builds. # Works for both emscripten and wasi. -# target/wasm32-wasi/wasm/libpowersync.a +# target/wasm32-wasip1/wasm/libpowersync.a cargo build \ -p powersync_loadable \ --profile wasm \ --no-default-features \ --features "powersync_core/static powersync_core/omit_load_extension sqlite_nostd/omit_load_extension" \ -Z build-std=panic_abort,core,alloc \ - --target wasm32-wasi + --target wasm32-wasip1 -cp "target/wasm32-wasi/wasm/libpowersync.a" "libpowersync-wasm.a" +cp "target/wasm32-wasip1/wasm/libpowersync.a" "libpowersync-wasm.a"
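For context on the `docs/sync.md` interface added above, here is a minimal sketch of how a client could drive `powersync_control` from Dart. It is illustrative only and not part of the change: it assumes the core extension has already been loaded into `sqlite3` (as the Dart tests in this diff do via their `loadExtension()` helper), and the `control` helper and instruction handling below are hypothetical, not an actual SDK API.

```dart
import 'dart:convert';

import 'package:sqlite3/sqlite3.dart';

/// Runs a single `powersync_control` call inside a transaction (as required
/// by the interface described in docs/sync.md) and decodes the returned
/// JSON array of instructions.
List<Object?> control(Database db, String operation, Object? payload) {
  db.execute('BEGIN');
  try {
    final [row] =
        db.select('SELECT powersync_control(?, ?) AS r', [operation, payload]);
    return jsonDecode(row['r'] as String) as List<Object?>;
  } finally {
    db.execute('COMMIT');
  }
}

void main() {
  // Assumes the powersync loadable extension was already registered with
  // sqlite3 before opening the database.
  final db = sqlite3.open(':memory:')..execute('SELECT powersync_init()');

  // Ask the client to start a sync iteration.
  final instructions = control(db, 'start', jsonEncode({'parameters': {}}));

  for (final instruction in instructions.cast<Map<String, Object?>>()) {
    if (instruction['LogLine'] case {'severity': final s, 'line': final l}) {
      print('[$s] $l');
    } else if (instruction.containsKey('EstablishSyncStream')) {
      // A real SDK would open the streaming HTTP request here and feed every
      // received sync line back via control(db, 'line_text', jsonEncode(line))
      // or control(db, 'line_binary', bsonBytes).
    }
  }

  db.dispose();
}
```

A full SDK would additionally react to `UpdateSyncStatus`, `FetchCredentials`, `CloseSyncStream`, and the other instructions listed in `docs/sync.md`, much like the golden tests in `dart/test/sync_test.dart` do.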