diff --git a/.github/actions/setup-rust/action.yml b/.github/actions/setup-rust/action.yml index 15a4b6888ca..62b6f114b01 100644 --- a/.github/actions/setup-rust/action.yml +++ b/.github/actions/setup-rust/action.yml @@ -47,16 +47,25 @@ runs: targets: "${{ inputs.targets }}" components: "${{ inputs.components }}" + - name: Configure sccache timeout + if: inputs.enable-sccache == 'true' + shell: bash + run: | + mkdir -p ~/.config/sccache + echo 'server_startup_timeout_ms = 15000' > ~/.config/sccache/config + - name: Rust Compile Cache if: inputs.enable-sccache == 'true' uses: mozilla-actions/sccache-action@v0.0.9 + - name: Pre-start sccache server + if: inputs.enable-sccache == 'true' + shell: bash + run: sccache --start-server & + - name: Install Protoc (for lance-encoding build step) if: runner.os != 'Windows' - uses: arduino/setup-protoc@v3 - with: - version: "29.3" - repo-token: ${{ inputs.repo-token }} + uses: ./.github/actions/setup-protoc - name: Install Ninja (for DuckDB build system) uses: seanmiddleditch/gha-setup-ninja@master diff --git a/.github/workflows/bench-pr.yml b/.github/workflows/bench-pr.yml index 539c9e45ef9..1e4c062f5a2 100644 --- a/.github/workflows/bench-pr.yml +++ b/.github/workflows/bench-pr.yml @@ -57,6 +57,8 @@ jobs: - uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} + - name: Setup benchmark environment + run: sudo bash scripts/setup-benchmark.sh - uses: ./.github/actions/setup-rust with: repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index 35c2c057cb8..852baa26cda 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -54,6 +54,8 @@ jobs: with: sccache: s3 - uses: actions/checkout@v6 + - name: Setup benchmark environment + run: sudo bash scripts/setup-benchmark.sh - uses: ./.github/actions/setup-rust with: repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 
98ad552ae4d..29d8b7c3590 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -482,55 +482,41 @@ jobs: --target x86_64-unknown-linux-gnu \ -p vortex-buffer -p vortex-ffi -p vortex-fastlanes -p vortex-fsst -p vortex-alp -p vortex-array - # cuda-build: - # if: github.repository == 'vortex-data/vortex' - # name: "CUDA build" - # timeout-minutes: 120 - # runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-build - # steps: - # - uses: runs-on/action@v2 - # with: - # sccache: s3 - # - uses: actions/checkout@v6 - # - uses: ./.github/actions/setup-rust - # with: - # repo-token: ${{ secrets.GITHUB_TOKEN }} - # - name: Build CUDA crates - # run: | - # cargo build --locked --all-features --all-targets \ - # -p vortex-cuda \ - # -p vortex-cub \ - # -p vortex-nvcomp \ - # -p gpu-scan-cli \ - # -p vortex-test-e2e-cuda - - # cuda-lint: - # if: github.repository == 'vortex-data/vortex' - # name: "CUDA (lint)" - # timeout-minutes: 120 - # runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-lint - # steps: - # - uses: runs-on/action@v2 - # with: - # sccache: s3 - # - uses: actions/checkout@v6 - # - uses: ./.github/actions/setup-rust - # with: - # repo-token: ${{ secrets.GITHUB_TOKEN }} - # - name: Clippy CUDA crates - # run: | - # cargo clippy --locked --all-features --all-targets \ - # -p vortex-cuda \ - # -p vortex-cub \ - # -p vortex-nvcomp \ - # -p gpu-scan-cli \ - # -p vortex-test-e2e-cuda \ - # -- -D warnings + cuda-build-lint: + if: github.repository == 'vortex-data/vortex' + name: "CUDA build & lint" + timeout-minutes: 120 + runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-build + steps: + - uses: runs-on/action@v2 + with: + sccache: s3 + - uses: actions/checkout@v6 + - uses: ./.github/actions/setup-rust + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Build CUDA crates + run: | + cargo build --locked --all-features --all-targets \ + -p vortex-cuda \ + -p vortex-cub \ + -p vortex-nvcomp \ + -p gpu-scan-cli \ + -p 
vortex-test-e2e-cuda + - name: Clippy CUDA crates + run: | + cargo clippy --locked --all-features --all-targets \ + -p vortex-cuda \ + -p vortex-cub \ + -p vortex-nvcomp \ + -p gpu-scan-cli \ + -p vortex-test-e2e-cuda \ + -- -D warnings cuda-test: if: github.repository == 'vortex-data/vortex' name: "CUDA tests" - timeout-minutes: 120 + timeout-minutes: 30 runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-tests steps: - uses: runs-on/action@v2 @@ -565,67 +551,67 @@ jobs: --target x86_64-unknown-linux-gnu \ --verbose - # cuda-test-sanitizer: - # if: github.repository == 'vortex-data/vortex' - # name: "CUDA tests (sanitizer)" - # timeout-minutes: 120 - # runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-test-sanitizer - # steps: - # - uses: runs-on/action@v2 - # with: - # sccache: s3 - # - name: Display NVIDIA SMI details - # run: | - # nvidia-smi - # nvidia-smi -L - # nvidia-smi -q -d Memory - # - uses: actions/checkout@v6 - # - uses: ./.github/actions/setup-rust - # with: - # repo-token: ${{ secrets.GITHUB_TOKEN }} - # - name: CUDA - memcheck - # env: - # CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool memcheck --leak-check=full --error-exitcode 1 - # run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu - # # TODO(joe): try to re-enable, This is hanging in CI. 
- # # - name: CUDA - racecheck - # # env: - # # CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool racecheck --error-exitcode 1 - # # run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu - # - name: CUDA - synccheck - # env: - # CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool synccheck --error-exitcode 1 - # run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu - # - name: CUDA - initcheck - # env: - # CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool initcheck --error-exitcode 1 - # run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu - - # cuda-test-cudf: - # if: github.repository == 'vortex-data/vortex' - # name: "CUDA tests (cudf)" - # timeout-minutes: 120 - # runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-test-cudf - # steps: - # - uses: runs-on/action@v2 - # with: - # sccache: s3 - # - name: Display NVIDIA SMI details - # run: | - # nvidia-smi - # nvidia-smi -L - # nvidia-smi -q -d Memory - # - uses: actions/checkout@v6 - # - uses: ./.github/actions/setup-rust - # with: - # repo-token: ${{ secrets.GITHUB_TOKEN }} - # - name: Build cudf test library - # run: cargo build --locked -p vortex-test-e2e-cuda --target x86_64-unknown-linux-gnu - # - name: Download and run cudf-test-harness - # run: | - # curl -fsSL https://github.com/vortex-data/cudf-test-harness/releases/latest/download/cudf-test-harness-x86_64.tar.gz | tar -xz - # cd cudf-test-harness-x86_64 - # compute-sanitizer --tool memcheck --error-exitcode 1 ./cudf-test-harness check $GITHUB_WORKSPACE/target/x86_64-unknown-linux-gnu/debug/libvortex_test_e2e_cuda.so + cuda-test-sanitizer: + if: github.repository == 'vortex-data/vortex' + name: "CUDA tests (sanitizer)" + timeout-minutes: 30 + runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-test-sanitizer + steps: + - uses: runs-on/action@v2 + with: + sccache: s3 + - name: 
Display NVIDIA SMI details + run: | + nvidia-smi + nvidia-smi -L + nvidia-smi -q -d Memory + - uses: actions/checkout@v6 + - uses: ./.github/actions/setup-rust + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: CUDA - memcheck + env: + CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool memcheck --leak-check=full --error-exitcode 1 + run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu + # TODO(joe): try to re-enable, This is hanging in CI. + # - name: CUDA - racecheck + # env: + # CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool racecheck --error-exitcode 1 + # run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu + - name: CUDA - synccheck + env: + CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool synccheck --error-exitcode 1 + run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu + - name: CUDA - initcheck + env: + CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: compute-sanitizer --tool initcheck --error-exitcode 1 + run: cargo test --locked -p vortex-cuda --all-features --target x86_64-unknown-linux-gnu + + cuda-test-cudf: + if: github.repository == 'vortex-data/vortex' + name: "CUDA tests (cudf)" + timeout-minutes: 30 + runs-on: runs-on=${{ github.run_id }}/runner=gpu/tag=cuda-test-cudf + steps: + - uses: runs-on/action@v2 + with: + sccache: s3 + - name: Display NVIDIA SMI details + run: | + nvidia-smi + nvidia-smi -L + nvidia-smi -q -d Memory + - uses: actions/checkout@v6 + - uses: ./.github/actions/setup-rust + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Build cudf test library + run: cargo build --locked -p vortex-test-e2e-cuda --target x86_64-unknown-linux-gnu + - name: Download and run cudf-test-harness + run: | + curl -fsSL https://github.com/vortex-data/cudf-test-harness/releases/latest/download/cudf-test-harness-x86_64.tar.gz | tar -xz + cd cudf-test-harness-x86_64 + 
compute-sanitizer --tool memcheck --error-exitcode 1 ./cudf-test-harness check $GITHUB_WORKSPACE/target/x86_64-unknown-linux-gnu/debug/libvortex_test_e2e_cuda.so rust-test-other: name: "Rust tests (${{ matrix.os }})" @@ -740,6 +726,8 @@ jobs: with: sccache: s3 - uses: actions/checkout@v6 + - name: Setup benchmark environment + run: sudo bash scripts/setup-benchmark.sh - uses: ./.github/actions/setup-rust with: repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/sql-benchmarks.yml b/.github/workflows/sql-benchmarks.yml index b5152e36148..bb8233ae56e 100644 --- a/.github/workflows/sql-benchmarks.yml +++ b/.github/workflows/sql-benchmarks.yml @@ -118,6 +118,8 @@ jobs: - uses: actions/checkout@v6 if: inputs.mode != 'pr' + - name: Setup benchmark environment + run: sudo bash scripts/setup-benchmark.sh - uses: ./.github/actions/setup-rust with: repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/Cargo.lock b/Cargo.lock index 127d058d719..d380a3d6229 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9730,6 +9730,7 @@ dependencies = [ "tabled", "termtree", "tracing", + "uuid", "vortex-array", "vortex-buffer", "vortex-error", @@ -10140,6 +10141,8 @@ dependencies = [ "libfuzzer-sys", "strum 0.27.2", "tokio", + "tracing", + "tracing-subscriber", "vortex", "vortex-array", "vortex-btrblocks", diff --git a/docs/api/cpp/index.rst b/docs/api/cpp/index.rst index 43d6c2ae409..befb980f6a9 100644 --- a/docs/api/cpp/index.rst +++ b/docs/api/cpp/index.rst @@ -20,7 +20,7 @@ Installation The C++ bindings are built using CMake. Requirements: * CMake 3.22 or higher -* C++17 compatible compiler +* C++20 compatible compiler * Rust toolchain (for building the underlying Rust library) .. 
code-block:: bash diff --git a/docs/developer-guide/internals/vtables.md b/docs/developer-guide/internals/vtables.md index 45ca3fe558a..e90c8a8f441 100644 --- a/docs/developer-guide/internals/vtables.md +++ b/docs/developer-guide/internals/vtables.md @@ -63,21 +63,21 @@ registered in the session by their ID. For a concept `Foo`, the components are organized into these files: -| File | Contains | -|---------------|-----------------------------------------------------------------------| -| `vtable.rs` | `FooVTable` trait definition | -| `typed.rs` | `Foo` data struct, inherent methods, `Deref` impl | -| `erased.rs` | `FooRef` struct, `DynFoo` sealed trait, blanket impl | -| `plugin.rs` | `FooPlugin` trait, registration | -| `matcher.rs` | Downcasting helpers (`is`, `as_`, `as_opt`, pattern matching traits) | +| File | Contains | +|--------------|----------------------------------------------------------------------| +| `vtable.rs` | `FooVTable` trait definition | +| `typed.rs` | `Foo` data struct, inherent methods, `Deref` impl | +| `erased.rs` | `FooRef` struct, `DynFoo` sealed trait, blanket impl | +| `plugin.rs` | `FooPlugin` trait, registration | +| `matcher.rs` | Downcasting helpers (`is`, `as_`, `as_opt`, pattern matching traits) | For Array encodings, each encoding has its own module (e.g. 
`arrays/primitive/`): -| File | Contains | -|-------------------------|-------------------------------------------------------------| -| `arrays/foo/mod.rs` | `V::Array` associated type, encoding-specific methods on it | -| `arrays/foo/vtable.rs` | `ArrayVTable` impl for this encoding | -| `arrays/foo/compute/` | Compute kernel implementations | +| File | Contains | +|------------------------|-------------------------------------------------------------| +| `arrays/foo/mod.rs` | `V::Array` associated type, encoding-specific methods on it | +| `arrays/foo/vtable.rs` | `ArrayVTable` impl for this encoding | +| `arrays/foo/compute/` | Compute kernel implementations | ## Example: ExtDType diff --git a/encodings/fastlanes/benches/bitpacking_take.rs b/encodings/fastlanes/benches/bitpacking_take.rs index 9bf7ea4db79..271ba447af2 100644 --- a/encodings/fastlanes/benches/bitpacking_take.rs +++ b/encodings/fastlanes/benches/bitpacking_take.rs @@ -28,10 +28,10 @@ fn main() { #[divan::bench] fn take_10_stratified(bencher: Bencher) { - let values = fixture(1_000_000, 8); + let values = fixture(65_536, 8); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); - let indices = PrimitiveArray::from_iter((0..10).map(|i| i * 10_000)); + let indices = PrimitiveArray::from_iter((0..10).map(|i| i * 6_553)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) @@ -46,7 +46,7 @@ fn take_10_stratified(bencher: Bencher) { #[divan::bench] fn take_10_contiguous(bencher: Bencher) { - let values = fixture(1_000_000, 8); + let values = fixture(65_536, 8); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); let indices = buffer![0..10].into_array(); @@ -64,7 +64,7 @@ fn take_10_contiguous(bencher: Bencher) { #[divan::bench] fn take_10k_random(bencher: Bencher) { - let values = fixture(1_000_000, 8); + let 
values = fixture(65_536, 8); let range = Uniform::new(0, values.len()).unwrap(); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); @@ -85,7 +85,7 @@ fn take_10k_random(bencher: Bencher) { #[divan::bench] fn take_10k_contiguous(bencher: Bencher) { - let values = fixture(1_000_000, 8); + let values = fixture(65_536, 8); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); let indices = PrimitiveArray::from_iter(0..10_000); @@ -102,11 +102,11 @@ fn take_10k_contiguous(bencher: Bencher) { } #[divan::bench] -fn take_200k_dispersed(bencher: Bencher) { - let values = fixture(1_000_000, 8); +fn take_10k_dispersed(bencher: Bencher) { + let values = fixture(65_536, 8); let uncompressed = PrimitiveArray::new(values.clone(), Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); - let indices = PrimitiveArray::from_iter((0..200_000).map(|i| (i * 42) % values.len() as u64)); + let indices = PrimitiveArray::from_iter((0..10_000).map(|i| (i * 42) % values.len() as u64)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) @@ -120,11 +120,11 @@ fn take_200k_dispersed(bencher: Bencher) { } #[divan::bench] -fn take_200k_first_chunk_only(bencher: Bencher) { - let values = fixture(1_000_000, 8); +fn take_10k_first_chunk_only(bencher: Bencher) { + let values = fixture(65_536, 8); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); - let indices = PrimitiveArray::from_iter((0..200_000).map(|i| ((i * 42) % 1024) as u64)); + let indices = PrimitiveArray::from_iter((0..10_000).map(|i| ((i * 42) % 1024) as u64)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) @@ -154,8 +154,8 @@ fn fixture(len: usize, bits: usize) -> Buffer { // I've iterated on 
both thresholds (1) and (2) using this collection of benchmarks, and those // were roughly the best values that I found. -const BIG_BASE2: u32 = 1048576; -const NUM_EXCEPTIONS: u32 = 10000; +const BIG_BASE2: u32 = 65536; +const NUM_EXCEPTIONS: u32 = 1024; #[divan::bench] fn patched_take_10_stratified(bencher: Bencher) { @@ -169,7 +169,7 @@ fn patched_take_10_stratified(bencher: Bencher) { NUM_EXCEPTIONS as usize ); - let indices = PrimitiveArray::from_iter((0..10).map(|i| i * 10_000)); + let indices = PrimitiveArray::from_iter((0..10).map(|i| i * 6_653)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) @@ -273,11 +273,11 @@ fn patched_take_10k_contiguous_patches(bencher: Bencher) { } #[divan::bench] -fn patched_take_200k_dispersed(bencher: Bencher) { +fn patched_take_10k_dispersed(bencher: Bencher) { let values = (0u32..BIG_BASE2 + NUM_EXCEPTIONS).collect::>(); let uncompressed = PrimitiveArray::new(values.clone(), Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); - let indices = PrimitiveArray::from_iter((0..200_000).map(|i| (i * 42) % values.len() as u64)); + let indices = PrimitiveArray::from_iter((0..10_000).map(|i| (i * 42) % values.len() as u64)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) @@ -291,11 +291,11 @@ fn patched_take_200k_dispersed(bencher: Bencher) { } #[divan::bench] -fn patched_take_200k_first_chunk_only(bencher: Bencher) { +fn patched_take_10k_first_chunk_only(bencher: Bencher) { let values = (0u32..BIG_BASE2 + NUM_EXCEPTIONS).collect::>(); let uncompressed = PrimitiveArray::new(values, Validity::NonNullable); let packed = bitpack_to_best_bit_width(&uncompressed).unwrap(); - let indices = PrimitiveArray::from_iter((0..200_000).map(|i| ((i * 42) % 1024) as u64)); + let indices = PrimitiveArray::from_iter((0..10_000).map(|i| ((i * 42) % 1024) as u64)); bencher .with_inputs(|| (&packed, &indices, LEGACY_SESSION.create_execution_ctx())) 
diff --git a/encodings/runend/Cargo.toml b/encodings/runend/Cargo.toml index c6823f16bd3..01a5b8d7a3e 100644 --- a/encodings/runend/Cargo.toml +++ b/encodings/runend/Cargo.toml @@ -48,3 +48,7 @@ harness = false [[bench]] name = "run_end_compress" harness = false + +[[bench]] +name = "run_end_decode" +harness = false diff --git a/encodings/runend/benches/run_end_compress.rs b/encodings/runend/benches/run_end_compress.rs index 533045c7314..04fcfce1323 100644 --- a/encodings/runend/benches/run_end_compress.rs +++ b/encodings/runend/benches/run_end_compress.rs @@ -36,12 +36,6 @@ const BENCH_ARGS: &[(usize, usize)] = &[ (100_000, 256), (100_000, 1024), (100_000, 4096), - (1_000_000, 4), - (1_000_000, 16), - (1_000_000, 256), - (1_000_000, 1024), - (1_000_000, 4096), - (1_000_000, 8192), ]; #[divan::bench(args = BENCH_ARGS)] diff --git a/encodings/runend/benches/run_end_decode.rs b/encodings/runend/benches/run_end_decode.rs new file mode 100644 index 00000000000..9f64beabff5 --- /dev/null +++ b/encodings/runend/benches/run_end_decode.rs @@ -0,0 +1,380 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::unwrap_used, clippy::cast_possible_truncation)] + +use std::fmt; + +use divan::Bencher; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::compute::warm_up_vtables; +use vortex_array::validity::Validity; +use vortex_buffer::BitBuffer; +use vortex_buffer::BufferMut; +use vortex_runend::decompress_bool::runend_decode_bools; + +fn main() { + warm_up_vtables(); + divan::main(); +} + +/// Distribution types for bool benchmarks +#[derive(Clone, Copy)] +enum BoolDistribution { + /// Alternating true/false (50/50) + Alternating, + /// Mostly true (90% true runs) + MostlyTrue, + /// Mostly false (90% false runs) + MostlyFalse, + /// All true + AllTrue, + /// All false + AllFalse, +} + +impl fmt::Display for BoolDistribution { + fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { + match self { + BoolDistribution::Alternating => write!(f, "alternating"), + BoolDistribution::MostlyTrue => write!(f, "mostly_true"), + BoolDistribution::MostlyFalse => write!(f, "mostly_false"), + BoolDistribution::AllTrue => write!(f, "all_true"), + BoolDistribution::AllFalse => write!(f, "all_false"), + } + } +} + +#[derive(Clone, Copy)] +struct BoolBenchArgs { + total_length: usize, + avg_run_length: usize, + distribution: BoolDistribution, +} + +impl fmt::Display for BoolBenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}_{}_{}", + self.total_length, self.avg_run_length, self.distribution + ) + } +} + +/// Creates bool test data with configurable distribution +fn create_bool_test_data( + total_length: usize, + avg_run_length: usize, + distribution: BoolDistribution, +) -> (PrimitiveArray, BoolArray) { + let mut ends = BufferMut::::with_capacity(total_length / avg_run_length + 1); + let mut values = Vec::with_capacity(total_length / avg_run_length + 1); + + let mut pos = 0usize; + let mut run_index = 0usize; + + while pos < total_length { + let run_len = avg_run_length.min(total_length - pos); + pos += run_len; + ends.push(pos as u32); + + let val = match distribution { + BoolDistribution::Alternating => run_index.is_multiple_of(2), + BoolDistribution::MostlyTrue => !run_index.is_multiple_of(10), // 90% true + BoolDistribution::MostlyFalse => run_index.is_multiple_of(10), // 10% true (90% false) + BoolDistribution::AllTrue => true, + BoolDistribution::AllFalse => false, + }; + values.push(val); + run_index += 1; + } + + ( + PrimitiveArray::new(ends.freeze(), Validity::NonNullable), + BoolArray::from(BitBuffer::from(values)), + ) +} + +// Medium size: 10k elements with various run lengths and distributions +const BOOL_ARGS: &[BoolBenchArgs] = &[ + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::Alternating, + }, + BoolBenchArgs { + 
total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::Alternating, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, + distribution: BoolDistribution::Alternating, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::Alternating, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::MostlyTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::MostlyTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, + distribution: BoolDistribution::MostlyTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::MostlyTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::MostlyFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::MostlyFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, + distribution: BoolDistribution::MostlyFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::MostlyFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::AllTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::AllTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, + distribution: BoolDistribution::AllTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::AllTrue, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::AllFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::AllFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, 
+ distribution: BoolDistribution::AllFalse, + }, + BoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::AllFalse, + }, +]; + +#[divan::bench(args = BOOL_ARGS)] +fn decode_bool(bencher: Bencher, args: BoolBenchArgs) { + let BoolBenchArgs { + total_length, + avg_run_length, + distribution, + } = args; + let (ends, values) = create_bool_test_data(total_length, avg_run_length, distribution); + bencher + .with_inputs(|| (ends.clone(), values.clone())) + .bench_refs(|(ends, values)| { + runend_decode_bools(ends.clone(), values.clone(), 0, total_length) + }); +} + +/// Validity distribution for nullable benchmarks +#[derive(Clone, Copy)] +enum ValidityDistribution { + /// 90% valid + MostlyValid, + /// 50% valid + HalfValid, + /// 10% valid + MostlyNull, +} + +impl fmt::Display for ValidityDistribution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValidityDistribution::MostlyValid => write!(f, "mostly_valid"), + ValidityDistribution::HalfValid => write!(f, "half_valid"), + ValidityDistribution::MostlyNull => write!(f, "mostly_null"), + } + } +} + +#[derive(Clone, Copy)] +struct NullableBoolBenchArgs { + total_length: usize, + avg_run_length: usize, + distribution: BoolDistribution, + validity: ValidityDistribution, +} + +impl fmt::Display for NullableBoolBenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}_{}_{}_{}", + self.total_length, self.avg_run_length, self.distribution, self.validity + ) + } +} + +/// Creates nullable bool test data with configurable distribution and validity +fn create_nullable_bool_test_data( + total_length: usize, + avg_run_length: usize, + distribution: BoolDistribution, + validity: ValidityDistribution, +) -> (PrimitiveArray, BoolArray) { + let mut ends = BufferMut::::with_capacity(total_length / avg_run_length + 1); + let mut values = Vec::with_capacity(total_length / avg_run_length + 1); + let mut validity_bits = 
Vec::with_capacity(total_length / avg_run_length + 1); + + let mut pos = 0usize; + let mut run_index = 0usize; + + while pos < total_length { + let run_len = avg_run_length.min(total_length - pos); + pos += run_len; + ends.push(pos as u32); + + let val = match distribution { + BoolDistribution::Alternating => run_index.is_multiple_of(2), + BoolDistribution::MostlyTrue => !run_index.is_multiple_of(10), + BoolDistribution::MostlyFalse => run_index.is_multiple_of(10), + BoolDistribution::AllTrue => true, + BoolDistribution::AllFalse => false, + }; + values.push(val); + + let is_valid = match validity { + ValidityDistribution::MostlyValid => !run_index.is_multiple_of(10), + ValidityDistribution::HalfValid => run_index.is_multiple_of(2), + ValidityDistribution::MostlyNull => run_index.is_multiple_of(10), + }; + validity_bits.push(is_valid); + + run_index += 1; + } + + ( + PrimitiveArray::new(ends.freeze(), Validity::NonNullable), + BoolArray::new( + BitBuffer::from(values), + Validity::from(BitBuffer::from(validity_bits)), + ), + ) +} + +const NULLABLE_BOOL_ARGS: &[NullableBoolBenchArgs] = &[ + // Alternating with different validity + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::MostlyValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::HalfValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::MostlyNull, + }, + // MostlyTrue with different validity + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::MostlyTrue, + validity: ValidityDistribution::MostlyValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::MostlyTrue, + validity: 
ValidityDistribution::HalfValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 10, + distribution: BoolDistribution::MostlyTrue, + validity: ValidityDistribution::MostlyNull, + }, + // Different run lengths with MostlyValid + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 2, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::MostlyValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 100, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::MostlyValid, + }, + NullableBoolBenchArgs { + total_length: 10_000, + avg_run_length: 1000, + distribution: BoolDistribution::Alternating, + validity: ValidityDistribution::MostlyValid, + }, +]; + +#[divan::bench(args = NULLABLE_BOOL_ARGS)] +fn decode_bool_nullable(bencher: Bencher, args: NullableBoolBenchArgs) { + let NullableBoolBenchArgs { + total_length, + avg_run_length, + distribution, + validity, + } = args; + let (ends, values) = + create_nullable_bool_test_data(total_length, avg_run_length, distribution, validity); + bencher + .with_inputs(|| (ends.clone(), values.clone())) + .bench_refs(|(ends, values)| { + runend_decode_bools(ends.clone(), values.clone(), 0, total_length) + }); +} diff --git a/encodings/runend/public-api.lock b/encodings/runend/public-api.lock index 4f9516a32fc..433f06037fe 100644 --- a/encodings/runend/public-api.lock +++ b/encodings/runend/public-api.lock @@ -2,18 +2,20 @@ pub mod vortex_runend pub mod vortex_runend::compress -pub fn vortex_runend::compress::runend_decode_bools(ends: vortex_array::arrays::primitive::array::PrimitiveArray, values: vortex_array::arrays::bool::array::BoolArray, offset: usize, length: usize) -> vortex_error::VortexResult - pub fn vortex_runend::compress::runend_decode_primitive(ends: vortex_array::arrays::primitive::array::PrimitiveArray, values: vortex_array::arrays::primitive::array::PrimitiveArray, offset: usize, length: usize) -> 
vortex_error::VortexResult -pub fn vortex_runend::compress::runend_decode_typed_bool(run_ends: impl core::iter::traits::iterator::Iterator, values: &vortex_buffer::bit::buf::BitBuffer, values_validity: vortex_mask::Mask, values_nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_array::arrays::bool::array::BoolArray - pub fn vortex_runend::compress::runend_decode_typed_primitive(run_ends: impl core::iter::traits::iterator::Iterator, values: &[T], values_validity: vortex_mask::Mask, values_nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_array::arrays::primitive::array::PrimitiveArray pub fn vortex_runend::compress::runend_decode_varbinview(ends: vortex_array::arrays::primitive::array::PrimitiveArray, values: vortex_array::arrays::varbinview::array::VarBinViewArray, offset: usize, length: usize) -> vortex_error::VortexResult pub fn vortex_runend::compress::runend_encode(array: &vortex_array::arrays::primitive::array::PrimitiveArray) -> (vortex_array::arrays::primitive::array::PrimitiveArray, vortex_array::array::ArrayRef) +pub mod vortex_runend::decompress_bool + +pub fn vortex_runend::decompress_bool::runend_decode_bools(ends: vortex_array::arrays::primitive::array::PrimitiveArray, values: vortex_array::arrays::bool::array::BoolArray, offset: usize, length: usize) -> vortex_error::VortexResult + +pub fn vortex_runend::decompress_bool::runend_decode_typed_bool(run_ends: impl core::iter::traits::iterator::Iterator, values: &vortex_buffer::bit::buf::BitBuffer, values_validity: vortex_mask::Mask, values_nullability: vortex_array::dtype::nullability::Nullability, length: usize) -> vortex_array::array::ArrayRef + pub struct vortex_runend::RunEndArray impl vortex_runend::RunEndArray diff --git a/encodings/runend/src/array.rs b/encodings/runend/src/array.rs index f943b747cda..2201cb57107 100644 --- a/encodings/runend/src/array.rs +++ b/encodings/runend/src/array.rs @@ -39,10 +39,10 @@ use 
vortex_error::vortex_ensure; use vortex_error::vortex_panic; use vortex_session::VortexSession; -use crate::compress::runend_decode_bools; use crate::compress::runend_decode_primitive; use crate::compress::runend_decode_varbinview; use crate::compress::runend_encode; +use crate::decompress_bool::runend_decode_bools; use crate::kernel::PARENT_KERNELS; use crate::rules::RULES; @@ -486,7 +486,7 @@ pub(super) fn run_end_canonicalize( Ok(match array.dtype() { DType::Bool(_) => { let bools = array.values().clone().execute_as("values", ctx)?; - runend_decode_bools(pends, bools, array.offset(), array.len())?.into_array() + runend_decode_bools(pends, bools, array.offset(), array.len())? } DType::Primitive(..) => { let pvalues = array.values().clone().execute_as("values", ctx)?; diff --git a/encodings/runend/src/compress.rs b/encodings/runend/src/compress.rs index ff3462d961c..0a841c28194 100644 --- a/encodings/runend/src/compress.rs +++ b/encodings/runend/src/compress.rs @@ -188,24 +188,6 @@ pub fn runend_decode_primitive( })) } -pub fn runend_decode_bools( - ends: PrimitiveArray, - values: BoolArray, - offset: usize, - length: usize, -) -> VortexResult { - let validity_mask = values.validity_mask()?; - Ok(match_each_unsigned_integer_ptype!(ends.ptype(), |E| { - runend_decode_typed_bool( - trimmed_ends_iter(ends.as_slice::(), offset, length), - &values.to_bit_buffer(), - validity_mask, - values.dtype().nullability(), - length, - ) - })) -} - /// Decode a run-end encoded slice of values into a flat `Buffer` and `Validity`. /// /// This is the core decode loop shared by primitive and varbinview run-end decoding. 
@@ -285,47 +267,6 @@ pub fn runend_decode_typed_primitive( PrimitiveArray::new(decoded, validity) } -pub fn runend_decode_typed_bool( - run_ends: impl Iterator, - values: &BitBuffer, - values_validity: Mask, - values_nullability: Nullability, - length: usize, -) -> BoolArray { - match values_validity { - Mask::AllTrue(_) => { - let mut decoded = BitBufferMut::with_capacity(length); - for (end, value) in run_ends.zip_eq(values.iter()) { - decoded.append_n(value, end - decoded.len()); - } - BoolArray::new(decoded.freeze(), values_nullability.into()) - } - Mask::AllFalse(_) => BoolArray::new(BitBuffer::new_unset(length), Validity::AllInvalid), - Mask::Values(mask) => { - let mut decoded = BitBufferMut::with_capacity(length); - let mut decoded_validity = BitBufferMut::with_capacity(length); - for (end, value) in run_ends.zip_eq( - values - .iter() - .zip(mask.bit_buffer().iter()) - .map(|(v, is_valid)| is_valid.then_some(v)), - ) { - match value { - None => { - decoded_validity.append_n(false, end - decoded.len()); - decoded.append_n(false, end - decoded.len()); - } - Some(value) => { - decoded_validity.append_n(true, end - decoded.len()); - decoded.append_n(value, end - decoded.len()); - } - } - } - BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze())) - } - } -} - /// Decode a run-end encoded VarBinView array by expanding views directly. 
pub fn runend_decode_varbinview( ends: PrimitiveArray, diff --git a/encodings/runend/src/compute/compare.rs b/encodings/runend/src/compute/compare.rs index 0b02c474dc1..6cc4ae55b4d 100644 --- a/encodings/runend/src/compute/compare.rs +++ b/encodings/runend/src/compute/compare.rs @@ -16,7 +16,7 @@ use vortex_error::VortexResult; use crate::RunEndArray; use crate::RunEndVTable; -use crate::compress::runend_decode_bools; +use crate::decompress_bool::runend_decode_bools; impl CompareKernel for RunEndVTable { fn compare( @@ -31,13 +31,13 @@ impl CompareKernel for RunEndVTable { ConstantArray::new(const_scalar, lhs.values().len()).into_array(), Operator::from(operator), )?; - let decoded = runend_decode_bools( + return runend_decode_bools( lhs.ends().clone().execute::(ctx)?, values.execute::(ctx)?, lhs.offset(), lhs.len(), - )?; - return Ok(Some(decoded.into_array())); + ) + .map(Some); } // Otherwise, fall back diff --git a/encodings/runend/src/decompress_bool.rs b/encodings/runend/src/decompress_bool.rs new file mode 100644 index 00000000000..407745d9154 --- /dev/null +++ b/encodings/runend/src/decompress_bool.rs @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Optimized run-end decoding for boolean arrays. +//! +//! Uses an adaptive strategy that pre-fills the buffer with the majority value +//! (0s or 1s) and only fills the minority runs, minimizing work for skewed distributions. 
+ +use itertools::Itertools; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::match_each_unsigned_integer_ptype; +use vortex_array::scalar::Scalar; +use vortex_array::validity::Validity; +use vortex_buffer::BitBuffer; +use vortex_buffer::BitBufferMut; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::iter::trimmed_ends_iter; + +/// Threshold for number of runs below which we use sequential append instead of prefill. +/// With few runs, the overhead of prefilling the entire buffer dominates. +const PREFILL_RUN_THRESHOLD: usize = 32; + +/// Decodes run-end encoded boolean values into a flat `BoolArray`. +pub fn runend_decode_bools( + ends: PrimitiveArray, + values: BoolArray, + offset: usize, + length: usize, +) -> VortexResult { + let validity = values.validity_mask()?; + let values_buf = values.to_bit_buffer(); + let nullability = values.dtype().nullability(); + + // Fast path for few runs with no offset - avoids iterator overhead + let num_runs = values_buf.len(); + if offset == 0 && num_runs < PREFILL_RUN_THRESHOLD { + return Ok(match_each_unsigned_integer_ptype!(ends.ptype(), |E| { + decode_few_runs_no_offset( + ends.as_slice::(), + &values_buf, + validity, + nullability, + length, + ) + })); + } + + Ok(match_each_unsigned_integer_ptype!(ends.ptype(), |E| { + runend_decode_typed_bool( + trimmed_ends_iter(ends.as_slice::(), offset, length), + &values_buf, + validity, + nullability, + length, + ) + })) +} + +/// Decodes run-end encoded boolean values using an adaptive strategy. 
+/// +/// The strategy counts true vs false runs and chooses the optimal approach: +/// - If more true runs: pre-fill with 1s, clear false runs +/// - If more false runs: pre-fill with 0s, fill true runs +/// +/// This minimizes work for skewed distributions (e.g., sparse validity masks). +pub fn runend_decode_typed_bool( + run_ends: impl Iterator, + values: &BitBuffer, + values_validity: Mask, + values_nullability: Nullability, + length: usize, +) -> ArrayRef { + match values_validity { + Mask::AllTrue(_) => { + decode_bool_non_nullable(run_ends, values, values_nullability, length).into_array() + } + Mask::AllFalse(_) => { + ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), length) + .into_array() + } + Mask::Values(mask) => { + decode_bool_nullable(run_ends, values, mask.bit_buffer(), length).into_array() + } + } +} + +/// Fast path for few runs with no offset. Uses direct slice access to minimize overhead. +/// This avoids the `trimmed_ends_iter` iterator chain which adds significant overhead +/// for small numbers of runs. 
+#[inline(always)] +fn decode_few_runs_no_offset( + ends: &[E], + values: &BitBuffer, + validity: Mask, + nullability: Nullability, + length: usize, +) -> ArrayRef { + match validity { + Mask::AllTrue(_) => { + let mut decoded = BitBufferMut::with_capacity(length); + let mut prev_end = 0usize; + for (i, &end) in ends.iter().enumerate() { + let end = end.as_().min(length); + decoded.append_n(values.value(i), end - prev_end); + prev_end = end; + } + BoolArray::new(decoded.freeze(), nullability.into()).into_array() + } + Mask::AllFalse(_) => { + ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), length) + .into_array() + } + Mask::Values(mask) => { + let validity_buf = mask.bit_buffer(); + let mut decoded = BitBufferMut::with_capacity(length); + let mut decoded_validity = BitBufferMut::with_capacity(length); + let mut prev_end = 0usize; + for (i, &end) in ends.iter().enumerate() { + let end = end.as_().min(length); + let run_len = end - prev_end; + let is_valid = validity_buf.value(i); + if is_valid { + decoded_validity.append_n(true, run_len); + decoded.append_n(values.value(i), run_len); + } else { + decoded_validity.append_n(false, run_len); + decoded.append_n(false, run_len); + } + prev_end = end; + } + BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze())).into_array() + } + } +} + +/// Decodes run-end encoded booleans when all values are valid (non-nullable). 
+fn decode_bool_non_nullable( + run_ends: impl Iterator, + values: &BitBuffer, + nullability: Nullability, + length: usize, +) -> BoolArray { + let num_runs = values.len(); + + // For few runs, sequential append is faster than prefill + modify + if num_runs < PREFILL_RUN_THRESHOLD { + let mut decoded = BitBufferMut::with_capacity(length); + for (end, value) in run_ends.zip(values.iter()) { + decoded.append_n(value, end - decoded.len()); + } + return BoolArray::new(decoded.freeze(), nullability.into()); + } + + // Adaptive strategy: prefill with majority value, only flip minority runs + let prefill = values.true_count() > num_runs - values.true_count(); + let mut decoded = BitBufferMut::full(prefill, length); + let mut current_pos = 0usize; + + for (end, value) in run_ends.zip_eq(values.iter()) { + if end > current_pos && value != prefill { + // SAFETY: current_pos < end <= length == decoded.len() + unsafe { decoded.fill_range_unchecked(current_pos, end, value) }; + } + current_pos = end; + } + BoolArray::new(decoded.freeze(), nullability.into()) +} + +/// Decodes run-end encoded booleans when values may be null (nullable). 
+fn decode_bool_nullable( + run_ends: impl Iterator, + values: &BitBuffer, + validity_mask: &BitBuffer, + length: usize, +) -> BoolArray { + let num_runs = values.len(); + + // For few runs, sequential append is faster than prefill + modify + if num_runs < PREFILL_RUN_THRESHOLD { + return decode_nullable_sequential(run_ends, values, validity_mask, length); + } + + // Adaptive strategy: prefill each buffer with its majority value + let prefill_decoded = values.true_count() > num_runs - values.true_count(); + let prefill_valid = validity_mask.true_count() > num_runs - validity_mask.true_count(); + + let mut decoded = BitBufferMut::full(prefill_decoded, length); + let mut decoded_validity = BitBufferMut::full(prefill_valid, length); + let mut current_pos = 0usize; + + for (end, (value, is_valid)) in run_ends.zip_eq(values.iter().zip(validity_mask.iter())) { + if end > current_pos { + // SAFETY: current_pos < end <= length == decoded.len() == decoded_validity.len() + if is_valid != prefill_valid { + unsafe { decoded_validity.fill_range_unchecked(current_pos, end, is_valid) }; + } + // Decoded bit should be the actual value when valid, false when null. + let want_decoded = is_valid && value; + if want_decoded != prefill_decoded { + unsafe { decoded.fill_range_unchecked(current_pos, end, want_decoded) }; + } + current_pos = end; + } + } + BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze())) +} + +/// Sequential decode for few runs - avoids prefill overhead. 
+#[inline(always)] +fn decode_nullable_sequential( + run_ends: impl Iterator, + values: &BitBuffer, + validity_mask: &BitBuffer, + length: usize, +) -> BoolArray { + let mut decoded = BitBufferMut::with_capacity(length); + let mut decoded_validity = BitBufferMut::with_capacity(length); + + for (end, (value, is_valid)) in run_ends.zip(values.iter().zip(validity_mask.iter())) { + let run_len = end - decoded.len(); + if is_valid { + decoded_validity.append_n(true, run_len); + decoded.append_n(value, run_len); + } else { + decoded_validity.append_n(false, run_len); + decoded.append_n(false, run_len); + } + } + + BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze())) +} + +#[cfg(test)] +mod tests { + use vortex_array::ToCanonical; + use vortex_array::arrays::BoolArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::assert_arrays_eq; + use vortex_array::validity::Validity; + use vortex_buffer::BitBuffer; + use vortex_error::VortexResult; + + use super::runend_decode_bools; + + #[test] + fn decode_bools_alternating() -> VortexResult<()> { + // Alternating true/false: [T, T, F, F, F, T, T, T, T, T] + let ends = PrimitiveArray::from_iter([2u32, 5, 10]); + let values = BoolArray::from(BitBuffer::from(vec![true, false, true])); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + let expected = BoolArray::from(BitBuffer::from(vec![ + true, true, false, false, false, true, true, true, true, true, + ])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_mostly_true() -> VortexResult<()> { + // Mostly true: [T, T, T, T, T, F, T, T, T, T] + let ends = PrimitiveArray::from_iter([5u32, 6, 10]); + let values = BoolArray::from(BitBuffer::from(vec![true, false, true])); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + let expected = BoolArray::from(BitBuffer::from(vec![ + true, true, true, true, true, false, true, true, true, true, + ])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } 
+ + #[test] + fn decode_bools_mostly_false() -> VortexResult<()> { + // Mostly false: [F, F, F, F, F, T, F, F, F, F] + let ends = PrimitiveArray::from_iter([5u32, 6, 10]); + let values = BoolArray::from(BitBuffer::from(vec![false, true, false])); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + let expected = BoolArray::from(BitBuffer::from(vec![ + false, false, false, false, false, true, false, false, false, false, + ])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_all_true_single_run() -> VortexResult<()> { + let ends = PrimitiveArray::from_iter([10u32]); + let values = BoolArray::from(BitBuffer::from(vec![true])); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + let expected = BoolArray::from(BitBuffer::from(vec![ + true, true, true, true, true, true, true, true, true, true, + ])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_all_false_single_run() -> VortexResult<()> { + let ends = PrimitiveArray::from_iter([10u32]); + let values = BoolArray::from(BitBuffer::from(vec![false])); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + let expected = BoolArray::from(BitBuffer::from(vec![ + false, false, false, false, false, false, false, false, false, false, + ])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_with_offset() -> VortexResult<()> { + // Test with offset: [T, T, F, F, F, T, T, T, T, T] -> slice [2..8] = [F, F, F, T, T, T] + let ends = PrimitiveArray::from_iter([2u32, 5, 10]); + let values = BoolArray::from(BitBuffer::from(vec![true, false, true])); + let decoded = runend_decode_bools(ends, values, 2, 6)?; + + let expected = + BoolArray::from(BitBuffer::from(vec![false, false, false, true, true, true])); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_nullable() -> VortexResult<()> { + use vortex_array::validity::Validity; + + // 3 runs: T (valid), F (null), T (valid) -> 
[T, T, null, null, null, T, T, T, T, T] + let ends = PrimitiveArray::from_iter([2u32, 5, 10]); + let values = BoolArray::new( + BitBuffer::from(vec![true, false, true]), + Validity::from(BitBuffer::from(vec![true, false, true])), + ); + let decoded = runend_decode_bools(ends, values, 0, 10)?; + + // Expected: values=[T, T, F, F, F, T, T, T, T, T], validity=[1, 1, 0, 0, 0, 1, 1, 1, 1, 1] + let expected = BoolArray::new( + BitBuffer::from(vec![ + true, true, false, false, false, true, true, true, true, true, + ]), + Validity::from(BitBuffer::from(vec![ + true, true, false, false, false, true, true, true, true, true, + ])), + ); + assert_arrays_eq!(decoded, expected); + Ok(()) + } + + #[test] + fn decode_bools_nullable_few_runs() -> VortexResult<()> { + // Test few runs (uses fast path): 5 runs of length 2000 each + let ends = PrimitiveArray::from_iter([2000u32, 4000, 6000, 8000, 10000]); + let values = BoolArray::new( + BitBuffer::from(vec![true, false, true, false, true]), + Validity::from(BitBuffer::from(vec![true, false, true, false, true])), + ); + let decoded = runend_decode_bools(ends, values, 0, 10000)?.to_bool(); + + // Check length and a few values + assert_eq!(decoded.len(), 10000); + // First run: valid true + assert!(decoded.validity_mask()?.value(0)); + assert!(decoded.to_bit_buffer().value(0)); + // Second run: null (validity false) + assert!(!decoded.validity_mask()?.value(2000)); + // Third run: valid true + assert!(decoded.validity_mask()?.value(4000)); + assert!(decoded.to_bit_buffer().value(4000)); + Ok(()) + } +} diff --git a/encodings/runend/src/lib.rs b/encodings/runend/src/lib.rs index bf289a23823..e83d77e287c 100644 --- a/encodings/runend/src/lib.rs +++ b/encodings/runend/src/lib.rs @@ -13,6 +13,7 @@ mod array; mod arrow; pub mod compress; mod compute; +pub mod decompress_bool; mod iter; mod kernel; mod ops; diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index b81635bdf5d..b80a00d66fa 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -50,6 
+50,9 @@ vortex-file = { workspace = true, optional = true } tokio = { workspace = true, features = ["rt", "macros"], optional = true } vortex-cuda = { workspace = true, optional = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + [lints] workspace = true diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index e474df342b2..dded332b64c 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -4,18 +4,32 @@ #![no_main] #![allow(clippy::unwrap_used, clippy::result_large_err)] +use std::str::FromStr; + use libfuzzer_sys::Corpus; use libfuzzer_sys::fuzz_target; +use tracing::level_filters::LevelFilter; use vortex_error::vortex_panic; use vortex_fuzz::FuzzArrayAction; use vortex_fuzz::run_fuzz_action; -fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { +fuzz_target!( + init: { + let fmt = tracing_subscriber::fmt::format() + .with_ansi(false) // Colour output is messed up in raw logs + .without_time() // We run fuzzer in CI which prepends timestamps + .compact(); + let level = std::env::var("RUST_LOG").map( + |v| LevelFilter::from_str(v.as_str()).unwrap()).unwrap_or(LevelFilter::INFO); + tracing_subscriber::fmt() + .event_format(fmt) + .with_max_level(level) + .init(); + }, + |fuzz_action: FuzzArrayAction| -> Corpus { match run_fuzz_action(fuzz_action) { Ok(true) => Corpus::Keep, Ok(false) => Corpus::Reject, - Err(e) => { - vortex_panic!("{e}"); - } + Err(e) => vortex_panic!("{e}"), } }); diff --git a/fuzz/src/array/mask.rs b/fuzz/src/array/mask.rs index 692a13cd35d..ac82caaaeb4 100644 --- a/fuzz/src/array/mask.rs +++ b/fuzz/src/array/mask.rs @@ -23,28 +23,31 @@ use vortex_error::VortexResult; use vortex_mask::AllOr; use vortex_mask::Mask; -/// Set to false any entries for which the mask is true. -/// +/// Apply a logical AND of a validity and a mask. +/// This needs to be coherent with applications of Mask. /// The result is always nullable. 
The result has the same length as self. #[inline] pub fn mask_validity(validity: &Validity, mask: &Mask) -> Validity { - match mask.bit_buffer() { - AllOr::All => Validity::AllInvalid, - AllOr::None => validity.clone().into_nullable(), - AllOr::Some(make_invalid) => match validity { + let out = match mask.bit_buffer() { + AllOr::All => validity.clone().into_nullable(), + AllOr::None => Validity::AllInvalid, + AllOr::Some(make_valid) => match validity { + Validity::AllInvalid => Validity::AllInvalid, Validity::NonNullable | Validity::AllValid => { - Validity::from_bit_buffer(!make_invalid, Nullability::Nullable) + Validity::from_bit_buffer(make_valid.clone(), Nullability::Nullable) } - Validity::AllInvalid => Validity::AllInvalid, Validity::Array(is_valid) => { let is_valid = is_valid.to_bool(); Validity::from_bit_buffer( - is_valid.to_bit_buffer() & !make_invalid, + is_valid.to_bit_buffer() & make_valid, Nullability::Nullable, ) } }, - } + }; + + tracing::debug!(validity = ?validity, mask = ?mask, out = ?out, "generated fuzzer mask"); + out } /// Apply mask on the canonical form of the array to get a consistent baseline. 
@@ -125,7 +128,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult { // Recursively mask the storage array - let masked_storage = mask_canonical_array(array.storage().to_canonical()?, mask) + let masked_storage = mask_canonical_array(array.storage_array().to_canonical()?, mask) .vortex_expect("mask_canonical_array should succeed in fuzz test"); let ext_dtype = array @@ -173,7 +176,7 @@ mod tests { #[test] fn test_mask_bool_array() { let array = BoolArray::from_iter([true, false, true, false, true]); - let mask = Mask::from_iter([true, false, false, true, false]); + let mask = Mask::from_iter([false, true, true, false, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -184,7 +187,7 @@ mod tests { #[test] fn test_mask_primitive_array() { let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); - let mask = Mask::from_iter([false, true, false, true, false]); + let mask = Mask::from_iter([true, false, true, false, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -195,7 +198,7 @@ mod tests { #[test] fn test_mask_primitive_array_with_nulls() { let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]); - let mask = Mask::from_iter([true, false, false, true, false]); + let mask = Mask::from_iter([false, true, true, false, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -210,7 +213,7 @@ mod tests { [Some(1i128), Some(2), Some(3), Some(4), Some(5)], dtype, ); - let mask = Mask::from_iter([false, false, true, false, false]); + let mask = Mask::from_iter([true, true, false, true, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -222,7 +225,7 @@ mod tests { #[test] fn test_mask_varbinview_array() { let array = VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]); - let mask = Mask::from_iter([true, false, true, false, 
true]); + let mask = Mask::from_iter([false, true, false, true, false]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -241,7 +244,7 @@ mod tests { .with_zero_copy_to_list(true) }; - let mask = Mask::from_iter([false, true, false]); + let mask = Mask::from_iter([true, false, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -257,7 +260,7 @@ mod tests { let array = FixedSizeListArray::try_new(elements, 2, Nullability::NonNullable.into(), 3).unwrap(); - let mask = Mask::from_iter([true, false, true]); + let mask = Mask::from_iter([false, true, false]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -281,7 +284,7 @@ mod tests { ) .unwrap(); - let mask = Mask::from_iter([false, true, false]); + let mask = Mask::from_iter([true, false, true]); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -292,9 +295,9 @@ mod tests { } #[test] - fn test_mask_all_true() { + fn test_mask_all_false() { let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); - let mask = Mask::AllTrue(5); + let mask = Mask::AllFalse(5); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -303,9 +306,9 @@ mod tests { } #[test] - fn test_mask_all_false() { + fn test_mask_all_true() { let array = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]); - let mask = Mask::AllFalse(5); + let mask = Mask::AllTrue(5); let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); @@ -317,10 +320,9 @@ mod tests { #[test] fn test_mask_empty_array() { let array = PrimitiveArray::from_iter(Vec::::new()); - let mask = Mask::AllFalse(0); - - let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); - - assert_eq!(result.len(), 0); + for mask in [Mask::AllFalse(0), Mask::AllTrue(0)] { + let result = mask_canonical_array(array.to_canonical().unwrap(), &mask).unwrap(); + assert_eq!(result.len(), 0); 
+ } } } diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index acd461c3a32..dbeebaa592a 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -37,6 +37,7 @@ use strum::EnumIter; use strum::IntoEnumIterator; pub(crate) use sum::*; pub(crate) use take::*; +use tracing::debug; use vortex_array::ArrayRef; use vortex_array::DynArray; use vortex_array::IntoArray; @@ -560,7 +561,14 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz let FuzzArrayAction { array, actions } = fuzz_action; let mut current_array = array.to_array(); + debug!( + "Initial array:\nTree:\n{}Values:\n{:#}", + current_array.display_tree(), + current_array.display_values() + ); + for (i, (action, expected)) in actions.into_iter().enumerate() { + debug!(id = i, action = ?action); match action { Action::Compress(strategy) => { let canonical = current_array diff --git a/fuzz/src/array/scalar_at.rs b/fuzz/src/array/scalar_at.rs index 4d43ec7aac4..8c28f975675 100644 --- a/fuzz/src/array/scalar_at.rs +++ b/fuzz/src/array/scalar_at.rs @@ -91,7 +91,8 @@ pub fn scalar_at_canonical_array(canonical: Canonical, index: usize) -> VortexRe Scalar::struct_(array.dtype().clone(), field_scalars) } Canonical::Extension(array) => { - let storage_scalar = scalar_at_canonical_array(array.storage().to_canonical()?, index)?; + let storage_scalar = + scalar_at_canonical_array(array.storage_array().to_canonical()?, index)?; Scalar::extension_ref(array.ext_dtype().clone(), storage_scalar) } }) diff --git a/fuzz/src/error.rs b/fuzz/src/error.rs index 71e6589f895..86f1be5720f 100644 --- a/fuzz/src/error.rs +++ b/fuzz/src/error.rs @@ -97,12 +97,29 @@ impl Display for VortexFuzzError { "MinMax mismatch: expected {lhs:?} got {rhs:?} in step {step}\nBacktrace:\n{backtrace}" ) } - VortexFuzzError::ArrayNotEqual(expected, actual, idx, lhs, rhs, step, backtrace) => { + VortexFuzzError::ArrayNotEqual( + expected_scalar, + actual_scalar, + idx, + expected_array, + current_array, + 
step, + backtrace, + ) => { + let expected_tree = expected_array.display_tree(); + let current_tree = current_array.display_tree(); + let expected_values = expected_array.display_values(); + let current_values = current_array.display_values(); write!( f, - "{expected} != {actual} at index {idx}, lhs is {} rhs is {} in step {step}\nBacktrace:\n{backtrace}", - lhs.display_tree(), - rhs.display_tree(), + "Mismatch at step {step} at index {idx}\n\ + Expected scalar:\n{expected_scalar}\n\ + Actual scalar:\n{actual_scalar}\n\ + Expected tree:\n{expected_tree}\n\ + Current tree:\n{current_tree}\ + Expected values:\n{expected_values:#}\n\ + Current values:\n{current_values:#}\ + \n{backtrace}" ) } VortexFuzzError::DTypeMismatch(lhs, rhs, step, backtrace) => { diff --git a/scripts/compare-benchmark-jsons.py b/scripts/compare-benchmark-jsons.py index 1802325c4cd..eea999315ad 100644 --- a/scripts/compare-benchmark-jsons.py +++ b/scripts/compare-benchmark-jsons.py @@ -65,13 +65,19 @@ def extract_dataset_key(df): # Generate summary statistics df3["ratio"] = df3["value_pr"] / df3["value_base"] -df3["remark"] = pd.Series([""] * len(df3)) -df3["remark"] = df3["remark"].case_when( - [ - (df3["ratio"] >= regression_threshold, "🚨"), - (df3["ratio"] <= improvement_threshold, "🚀"), - ] -) + + +def extract_engine_and_file_format(name): + if not isinstance(name, str) or "/" not in name or ":" not in name: + return pd.Series({"engine": "unknown", "file_format": "unknown"}) + + target = name.rsplit("/", 1)[-1] + engine, file_format = target.split(":", 1) + return pd.Series({"engine": engine, "file_format": file_format}) + + +df3[["engine", "file_format"]] = df3["name"].apply(extract_engine_and_file_format) + # Filter for different target combinations for summary statistics vortex_df = df3[df3["name"].str.contains("vortex", case=False, na=False)] @@ -146,8 +152,6 @@ def format_performance(ratio, target_name): overall_performance = "no data" if pd.isna(geo_mean_ratio) else 
format_performance(geo_mean_ratio, "overall") vortex_performance = format_performance(vortex_geo_mean_ratio, "vortex") -duckdb_vortex_performance = format_performance(duckdb_vortex_geo_mean_ratio, "duckdb:vortex") -datafusion_vortex_performance = format_performance(datafusion_vortex_geo_mean_ratio, "datafusion:vortex") parquet_performance = format_performance(parquet_geo_mean_ratio, "parquet") @@ -164,41 +168,97 @@ def format_performance(ratio, target_name): if len(parquet_df) > 0: summary_lines.extend([f"- **Parquet**: {parquet_performance}"]) -# Only add duckdb:vortex section if we have that data -if len(duckdb_vortex_df) > 0: - summary_lines.append(f"- **duckdb:vortex**: {duckdb_vortex_performance}") - -# Only add datafusion:vortex section if we have that data -if len(datafusion_vortex_df) > 0: - summary_lines.append(f"- **datafusion:vortex**: {datafusion_vortex_performance}") -# Only add best/worst if we have vortex data -if len(vortex_df) > 0: - summary_lines.extend( - [ - f"- **Best**: {best_improvement}", - f"- **Worst**: {worst_regression}", - f"- **Significant (>{threshold_pct}%)**: {significant_improvements}↑ {significant_regressions}↓", - ] +ENGINE_ORDER = { + "vortex": 0, + "datafusion": 1, + "duckdb": 2, + "lance": 3, + "arrow": 4, +} + +FILE_FORMAT_ORDER = { + "vortex-file-compressed": 0, + "vortex-compact": 1, + "parquet": 2, + "lance": 3, + "duckdb": 4, + "arrow": 5, +} + + +def group_sort_key(group_key): + engine, file_format = group_key + return ( + ENGINE_ORDER.get(engine, len(ENGINE_ORDER)), + FILE_FORMAT_ORDER.get(file_format, len(FILE_FORMAT_ORDER)), + engine, + file_format, ) -# Build table -table_df = pd.DataFrame( - { - "name": df3["name"], - f"PR {pr_commit_id[:8]}": df3["value_pr"], - f"base {base_commit_id[:8]}": df3["value_base"], - "ratio (PR/base)": df3["ratio"], - "unit": df3["unit_base"], - "remark": df3["remark"], - } -) + +def build_group_summary(group_df): + geo_mean_ratio = calculate_geo_mean(group_df) + ratio_summary = 
format_performance(geo_mean_ratio, "group") + + significant_improvements = (group_df["ratio"] < improvement_threshold).sum() + significant_regressions = (group_df["ratio"] > regression_threshold).sum() + + return ratio_summary, significant_improvements, significant_regressions + + +def format_integer_value(value): + if pd.isna(value): + return "" + + return str(int(value)) + + +def format_name_with_highlight(name, ratio): + if pd.isna(ratio): + return name + + if ratio <= improvement_threshold: + return f"🚀 {name}" + + if ratio >= regression_threshold: + return f"🚨 {name}" + + return name + # Output complete formatted markdown print("\n".join(summary_lines)) print("") -print("
") -print("Detailed Results Table") -print("") -print(table_df.to_markdown(index=False, tablefmt="github", floatfmt=".2f")) -print("
") +grouped_tables = df3.groupby(["engine", "file_format"], dropna=False, sort=False) +for engine, file_format in sorted(grouped_tables.groups.keys(), key=group_sort_key): + group_df = grouped_tables.get_group((engine, file_format)).sort_values("name") + group_performance, significant_improvements, significant_regressions = build_group_summary(group_df) + unit = group_df["unit_base"].dropna().iloc[0] if group_df["unit_base"].notna().any() else "unit" + display_df = pd.DataFrame( + { + "name": [ + format_name_with_highlight(name, ratio) for name, ratio in zip(group_df["name"], group_df["ratio"]) + ], + f"PR {pr_commit_id[:8]} ({unit})": group_df["value_pr"].map(format_integer_value), + f"base {base_commit_id[:8]} ({unit})": group_df["value_base"].map(format_integer_value), + "ratio (PR/base)": group_df["ratio"], + } + ) + print("
") + summary_text = ( + f"{engine} / {file_format} ({group_performance}, {significant_improvements}↑ {significant_regressions}↓)" + ) + print(f"{summary_text}") + print("") + print("
") + print("") + print( + display_df.to_markdown( + index=False, + tablefmt="github", + floatfmt=".2f", + ) + ) + print("") + print("
") diff --git a/scripts/setup-benchmark.sh b/scripts/setup-benchmark.sh new file mode 100755 index 00000000000..d45dad2e7c5 --- /dev/null +++ b/scripts/setup-benchmark.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +set -Eeu -o pipefail -x + +if [ "$EUID" -ne 0 ]; then + echo "Environment setup script for benchmarks should run as root." + exit 0 +fi + +# Turn off frequency scaling +for gov in /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor; do + echo performance > "$gov" 2>/dev/null || true +done + +# Really discourage swapping to disk +sysctl vm.swappiness=0 + +# Disable ASLR - https://docs.kernel.org/admin-guide/sysctl/kernel.html#randomize-va-space +sysctl kernel.randomize_va_space=0 + +# Reduce kernel logging to minimum +dmesg -n 1 + +# Disable some unused services and features +systemctl stop apparmor ModemManager +systemctl disable apparmor ModemManager + +# mask prevents them from being started by other services +systemctl mask ModemManager + +# For apparmor specifically, also teardown loaded profiles +aa-teardown diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index e915c8e2002..575a41459b9 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -67,6 +67,7 @@ tabled = { workspace = true, optional = true, default-features = false, features ] } termtree = { workspace = true } tracing = { workspace = true } +uuid = { workspace = true } vortex-buffer = { workspace = true, features = ["arrow"] } vortex-error = { workspace = true, features = ["flatbuffers"] } vortex-flatbuffers = { workspace = true, features = ["array", "dtype"] } diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs index d60d972c850..1194e6cf97a 100644 --- a/vortex-array/benches/compare.rs +++ b/vortex-array/benches/compare.rs @@ -21,7 +21,7 @@ fn main() { divan::main(); } -const ARRAY_SIZE: usize = 10_000_000; +const ARRAY_SIZE: usize = 
65_536; #[divan::bench] fn compare_bool(bencher: Bencher) { diff --git a/vortex-array/benches/search_sorted.rs b/vortex-array/benches/search_sorted.rs index 353dec65a3b..86cdf774645 100644 --- a/vortex-array/benches/search_sorted.rs +++ b/vortex-array/benches/search_sorted.rs @@ -33,8 +33,8 @@ fn binary_search_vortex(bencher: Bencher) { fn fixture() -> (Vec, i32) { let mut rng = StdRng::seed_from_u64(0); - let range = Uniform::new(0, 1_000_000).unwrap(); - let mut data: Vec = (0..1_000_000).map(|_| rng.sample(range)).collect(); + let range = Uniform::new(0, 65_536).unwrap(); + let mut data: Vec = (0..65_536).map(|_| rng.sample(range)).collect(); data.sort(); (data, rng.sample(range)) diff --git a/vortex-array/benches/take_fsl.rs b/vortex-array/benches/take_fsl.rs index cea54c7200e..383ef192bcb 100644 --- a/vortex-array/benches/take_fsl.rs +++ b/vortex-array/benches/take_fsl.rs @@ -28,7 +28,7 @@ fn main() { } /// Number of lists in the source array. -const NUM_LISTS: usize = 10_000; +const NUM_LISTS: usize = 500; /// Number of indices to take. 
const NUM_INDICES: &[usize] = &[100, 1_000]; diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 26ed0b44adf..8b822467195 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -58,7 +58,7 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggr pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> @@ -286,7 +286,7 @@ pub type vortex_array::aggregate_fn::AggregateFnVTable::Options: 'static + core: pub type vortex_array::aggregate_fn::AggregateFnVTable::Partial: 'static + core::marker::Send -pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Columnar, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::AggregateFnVTable::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> @@ -316,7 +316,7 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggr pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = 
vortex_array::aggregate_fn::fns::sum::SumPartial -pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> @@ -1686,11 +1686,13 @@ impl vortex_array::arrays::ExtensionArray pub fn vortex_array::arrays::ExtensionArray::ext_dtype(&self) -> &vortex_array::dtype::extension::ExtDTypeRef -pub fn vortex_array::arrays::ExtensionArray::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self -pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage: vortex_array::ArrayRef) -> Self +pub unsafe fn vortex_array::arrays::ExtensionArray::new_unchecked(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self -pub fn vortex_array::arrays::ExtensionArray::storage(&self) -> &vortex_array::ArrayRef +pub fn vortex_array::arrays::ExtensionArray::storage_array(&self) -> &vortex_array::ArrayRef + +pub fn vortex_array::arrays::ExtensionArray::try_new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult impl vortex_array::arrays::ExtensionArray @@ -5498,11 +5500,13 @@ impl vortex_array::arrays::ExtensionArray pub fn vortex_array::arrays::ExtensionArray::ext_dtype(&self) -> &vortex_array::dtype::extension::ExtDTypeRef -pub fn 
vortex_array::arrays::ExtensionArray::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self + +pub unsafe fn vortex_array::arrays::ExtensionArray::new_unchecked(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self -pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage: vortex_array::ArrayRef) -> Self +pub fn vortex_array::arrays::ExtensionArray::storage_array(&self) -> &vortex_array::ArrayRef -pub fn vortex_array::arrays::ExtensionArray::storage(&self) -> &vortex_array::ArrayRef +pub fn vortex_array::arrays::ExtensionArray::try_new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult impl vortex_array::arrays::ExtensionArray @@ -10114,7 +10118,7 @@ pub fn vortex_array::dtype::PType::try_from_arrow(value: &arrow_schema::datatype pub mod vortex_array::dtype::extension -pub struct vortex_array::dtype::extension::ExtDType(_) +pub struct vortex_array::dtype::extension::ExtDType impl vortex_array::dtype::extension::ExtDType @@ -10134,21 +10138,21 @@ pub fn vortex_array::dtype::extension::ExtDType::try_with_vtable(vtable: V, m pub fn vortex_array::dtype::extension::ExtDType::vtable(&self) -> &V -impl core::clone::Clone for vortex_array::dtype::extension::ExtDType +impl core::clone::Clone for vortex_array::dtype::extension::ExtDType where ::Metadata: core::clone::Clone pub fn vortex_array::dtype::extension::ExtDType::clone(&self) -> vortex_array::dtype::extension::ExtDType -impl core::cmp::Eq for vortex_array::dtype::extension::ExtDType +impl core::cmp::Eq for vortex_array::dtype::extension::ExtDType where ::Metadata: core::cmp::Eq -impl core::cmp::PartialEq for vortex_array::dtype::extension::ExtDType +impl core::cmp::PartialEq for 
vortex_array::dtype::extension::ExtDType where ::Metadata: core::cmp::PartialEq pub fn vortex_array::dtype::extension::ExtDType::eq(&self, other: &vortex_array::dtype::extension::ExtDType) -> bool -impl core::fmt::Debug for vortex_array::dtype::extension::ExtDType +impl core::fmt::Debug for vortex_array::dtype::extension::ExtDType where ::Metadata: core::fmt::Debug pub fn vortex_array::dtype::extension::ExtDType::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl core::hash::Hash for vortex_array::dtype::extension::ExtDType +impl core::hash::Hash for vortex_array::dtype::extension::ExtDType where ::Metadata: core::hash::Hash pub fn vortex_array::dtype::extension::ExtDType::hash<__H: core::hash::Hasher>(&self, state: &mut __H) @@ -10176,7 +10180,7 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::with_nullability(&self, null impl vortex_array::dtype::extension::ExtDTypeRef -pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast(self) -> vortex_array::dtype::extension::ExtDType +pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast(self) -> alloc::sync::Arc> pub fn vortex_array::dtype::extension::ExtDTypeRef::is(&self) -> bool @@ -10184,7 +10188,7 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata(&self) -> core::option::Option<::Match> -pub fn vortex_array::dtype::extension::ExtDTypeRef::try_downcast(self) -> core::result::Result, vortex_array::dtype::extension::ExtDTypeRef> +pub fn vortex_array::dtype::extension::ExtDTypeRef::try_downcast(self) -> core::result::Result>, vortex_array::dtype::extension::ExtDTypeRef> impl core::clone::Clone for vortex_array::dtype::extension::ExtDTypeRef @@ -10232,11 +10236,11 @@ pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dty pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::dtype::extension::ExtVTable::unpack_native<'a>(&self, metadata: &'a 
Self::Metadata, storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::dtype::extension::ExtVTable::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::dtype::extension::ExtVTable::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::dtype::extension::ExtVTable::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::dtype::extension::ExtVTable::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::dtype::extension::ExtVTable::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::datetime::Date @@ -10250,11 +10254,11 @@ pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Date::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Date::validate_dtype(&self, metadata: 
&Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Date::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::extension::datetime::Date::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Date::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::datetime::Time @@ -10268,11 +10272,11 @@ pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Time::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::extension::datetime::Time::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) 
-> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Time::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::datetime::Timestamp @@ -10286,11 +10290,29 @@ pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array:: pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> + +pub fn vortex_array::extension::datetime::Timestamp::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> + +impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::uuid::Uuid + +pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension::uuid::UuidMetadata + +pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid -pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, _metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> 
vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Timestamp::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_array::extension::uuid::Uuid::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> + +pub fn vortex_array::extension::uuid::Uuid::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> pub trait vortex_array::dtype::extension::Matcher @@ -14118,11 +14140,11 @@ pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Date::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Date::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn 
vortex_array::extension::datetime::Date::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::extension::datetime::Date::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Date::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> pub struct vortex_array::extension::datetime::Time @@ -14168,11 +14190,11 @@ pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Time::unpack_native(&self, metadata: &Self::Metadata, _storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Time::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::extension::datetime::Time::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Time::validate_scalar_value(&self, ext_dtype: 
&vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> pub struct vortex_array::extension::datetime::Timestamp @@ -14220,11 +14242,11 @@ pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array:: pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, _metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Timestamp::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> -pub fn vortex_array::extension::datetime::Timestamp::validate_scalar_value(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> +pub fn vortex_array::extension::datetime::Timestamp::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> pub struct vortex_array::extension::datetime::TimestampOptions @@ -14256,6 +14278,82 @@ pub fn vortex_array::extension::datetime::TimestampOptions::hash<__H: core::hash impl core::marker::StructuralPartialEq for vortex_array::extension::datetime::TimestampOptions +pub mod vortex_array::extension::uuid 
+ +pub struct vortex_array::extension::uuid::Uuid + +impl core::clone::Clone for vortex_array::extension::uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::clone(&self) -> vortex_array::extension::uuid::Uuid + +impl core::cmp::Eq for vortex_array::extension::uuid::Uuid + +impl core::cmp::PartialEq for vortex_array::extension::uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::eq(&self, other: &vortex_array::extension::uuid::Uuid) -> bool + +impl core::default::Default for vortex_array::extension::uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::default() -> vortex_array::extension::uuid::Uuid + +impl core::fmt::Debug for vortex_array::extension::uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::extension::uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_array::extension::uuid::Uuid + +impl vortex_array::dtype::extension::ExtVTable for vortex_array::extension::uuid::Uuid + +pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension::uuid::UuidMetadata + +pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid + +pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_array::extension::uuid::Uuid::validate_dtype(&self, ext_dtype: 
&vortex_array::dtype::extension::ExtDType) -> vortex_error::VortexResult<()> + +pub fn vortex_array::extension::uuid::Uuid::validate_scalar_value(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult<()> + +pub struct vortex_array::extension::uuid::UuidMetadata + +pub vortex_array::extension::uuid::UuidMetadata::version: core::option::Option + +impl core::clone::Clone for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::clone(&self) -> vortex_array::extension::uuid::UuidMetadata + +impl core::cmp::Eq for vortex_array::extension::uuid::UuidMetadata + +impl core::cmp::PartialEq for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::eq(&self, other: &Self) -> bool + +impl core::default::Default for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::default() -> vortex_array::extension::uuid::UuidMetadata + +impl core::fmt::Debug for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::extension::uuid::UuidMetadata + +pub fn vortex_array::extension::uuid::UuidMetadata::hash(&self, state: &mut H) + pub struct vortex_array::extension::EmptyMetadata impl core::clone::Clone for vortex_array::extension::EmptyMetadata @@ -18538,7 +18636,7 @@ pub fn vortex_array::scalar_fn::EmptyOptions::hash<__H: core::hash::Hasher>(&sel impl core::marker::StructuralPartialEq for vortex_array::scalar_fn::EmptyOptions -pub struct vortex_array::scalar_fn::ScalarFn(_) +pub struct vortex_array::scalar_fn::ScalarFn impl 
vortex_array::scalar_fn::ScalarFn @@ -18586,6 +18684,10 @@ pub fn vortex_array::scalar_fn::ScalarFnRef::as_(&self) -> core::option::Option<&::Options> +pub fn vortex_array::scalar_fn::ScalarFnRef::downcast(self) -> alloc::sync::Arc> + +pub fn vortex_array::scalar_fn::ScalarFnRef::downcast_ref(&self) -> core::option::Option<&vortex_array::scalar_fn::ScalarFn> + pub fn vortex_array::scalar_fn::ScalarFnRef::execute(&self, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::ScalarFnRef::id(&self) -> vortex_array::scalar_fn::ScalarFnId @@ -18600,9 +18702,9 @@ pub fn vortex_array::scalar_fn::ScalarFnRef::return_dtype(&self, arg_types: &[vo pub fn vortex_array::scalar_fn::ScalarFnRef::signature(&self) -> vortex_array::scalar_fn::ScalarFnSignature<'_> -pub fn vortex_array::scalar_fn::ScalarFnRef::validity(&self, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult +pub fn vortex_array::scalar_fn::ScalarFnRef::try_downcast(self) -> core::result::Result>, vortex_array::scalar_fn::ScalarFnRef> -pub fn vortex_array::scalar_fn::ScalarFnRef::vtable_ref(&self) -> core::option::Option<&V> +pub fn vortex_array::scalar_fn::ScalarFnRef::validity(&self, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult impl core::clone::Clone for vortex_array::scalar_fn::ScalarFnRef diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index d3fab024739..5e9a12e53fb 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -7,7 +7,7 @@ use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; -use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::VortexSessionExecute; use crate::aggregate_fn::AggregateFn; @@ -131,11 +131,11 @@ impl DynAccumulator for Accumulator { batch = batch.execute(&mut ctx)?; } - // Otherwise, execute 
the batch until it is canonical and accumulate it into the state. - let canonical = batch.execute::(&mut ctx)?; + // Otherwise, execute the batch until it is columnar and accumulate it into the state. + let columnar = batch.execute::(&mut ctx)?; self.vtable - .accumulate(&mut self.partial, &canonical, &mut ctx) + .accumulate(&mut self.partial, &columnar, &mut ctx) } fn is_saturated(&self) -> bool { diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index ff85da637b9..17d165e3746 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -14,6 +14,7 @@ use vortex_session::VortexSession; use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; @@ -121,7 +122,11 @@ impl DynGroupedAccumulator for GroupedAccumulator { // We first execute the groups until it is a ListView or FixedSizeList, since we only // dispatch the aggregate kernel over the elements of these arrays. - match groups.clone().execute::(&mut ctx)? { + let canonical = match groups.clone().execute::(&mut ctx)? { + Columnar::Canonical(c) => c, + Columnar::Constant(c) => c.into_array().execute::(&mut ctx)?, + }; + match canonical { Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx), Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), @@ -192,7 +197,7 @@ impl GroupedAccumulator { } // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. 
- let elements = elements.execute::(ctx)?.into_array(); + let elements = elements.execute::(ctx)?.into_array(); let offsets = groups.offsets(); let sizes = groups.sizes().cast(offsets.dtype().clone())?; let validity = groups.validity().to_mask(offsets.len()); @@ -279,7 +284,7 @@ impl GroupedAccumulator { } // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); + let elements = elements.execute::(ctx)?.into_array(); let validity = groups.validity().to_mask(groups.len()); let mut accumulator = Accumulator::try_new( diff --git a/vortex-array/src/aggregate_fn/fns/sum.rs b/vortex-array/src/aggregate_fn/fns/sum.rs index 5b16b585e37..52af2f3c4fb 100644 --- a/vortex-array/src/aggregate_fn/fns/sum.rs +++ b/vortex-array/src/aggregate_fn/fns/sum.rs @@ -14,11 +14,13 @@ use vortex_mask::AllOr; use crate::ArrayRef; use crate::Canonical; +use crate::Columnar; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; use crate::arrays::BoolArray; +use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; use crate::arrays::PrimitiveArray; use crate::dtype::DType; @@ -149,7 +151,7 @@ impl AggregateFnVTable for Sum { fn accumulate( &self, partial: &mut Self::Partial, - batch: &Canonical, + batch: &Columnar, _ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut inner = match partial.current.take() { @@ -158,10 +160,13 @@ impl AggregateFnVTable for Sum { }; let result = match batch { - Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), - Canonical::Bool(b) => accumulate_bool(&mut inner, b), - Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), - _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), + Columnar::Canonical(c) => match c { + Canonical::Primitive(p) => accumulate_primitive(&mut inner, p), + Canonical::Bool(b) => accumulate_bool(&mut inner, b), + 
Canonical::Decimal(d) => accumulate_decimal(&mut inner, d), + _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), + }, + Columnar::Constant(c) => accumulate_constant(&mut inner, c), }; match result { @@ -349,6 +354,85 @@ fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult { Ok(checked_add_u64(acc, true_count)) } +/// Accumulate a constant array into the sum state. +/// Computes `scalar * len` and adds to the accumulator. +/// Returns Ok(true) if saturated (overflow), Ok(false) if not. +fn accumulate_constant(inner: &mut SumState, c: &ConstantArray) -> VortexResult { + let scalar = c.scalar(); + if scalar.is_null() || c.is_empty() { + return Ok(false); + } + let len = c.len(); + + match scalar.dtype() { + DType::Bool(_) => { + let SumState::Unsigned(acc) = inner else { + vortex_panic!("expected unsigned sum state for bool input"); + }; + let val = scalar + .as_bool() + .value() + .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?; + if val { + Ok(checked_add_u64(acc, len as u64)) + } else { + Ok(false) + } + } + DType::Primitive(..) => { + let pvalue = scalar + .as_primitive() + .pvalue() + .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?; + match inner { + SumState::Unsigned(acc) => { + let val = pvalue.cast::()?; + match val.checked_mul(len as u64) { + Some(product) => Ok(checked_add_u64(acc, product)), + None => Ok(true), + } + } + SumState::Signed(acc) => { + let val = pvalue.cast::()?; + match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) { + Some(product) => Ok(checked_add_i64(acc, product)), + None => Ok(true), + } + } + SumState::Float(acc) => { + let val = pvalue.cast::()?; + *acc += val * len as f64; + Ok(false) + } + SumState::Decimal(_) => { + vortex_panic!("decimal sum state with primitive input") + } + } + } + DType::Decimal(..) 
=> { + let SumState::Decimal(acc) = inner else { + vortex_panic!("expected decimal sum state for decimal input"); + }; + let val = scalar + .as_decimal() + .decimal_value() + .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?; + let len_decimal = DecimalValue::from(len as i128); + match val.checked_mul(&len_decimal) { + Some(product) => match acc.checked_add(&product) { + Some(r) => { + *acc = r; + Ok(false) + } + None => Ok(true), + }, + None => Ok(true), + } + } + _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()), + } +} + /// Accumulate a decimal array into the sum state. /// Returns Ok(true) if saturated (overflow), Ok(false) if not. fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult { diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 7c53387b56b..0e45e8a54fd 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -12,7 +12,7 @@ use vortex_error::vortex_bail; use vortex_session::VortexSession; use crate::ArrayRef; -use crate::Canonical; +use crate::Columnar; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; @@ -95,7 +95,7 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { fn accumulate( &self, state: &mut Self::Partial, - batch: &Canonical, + batch: &Columnar, ctx: &mut ExecutionCtx, ) -> VortexResult<()>; diff --git a/vortex-array/src/arrays/datetime/mod.rs b/vortex-array/src/arrays/datetime/mod.rs index b10c4ad4668..e3052147a51 100644 --- a/vortex-array/src/arrays/datetime/mod.rs +++ b/vortex-array/src/arrays/datetime/mod.rs @@ -125,7 +125,7 @@ impl TemporalArray { /// These values are to be interpreted based on the time unit and optional time-zone stored /// in the TemporalMetadata. pub fn temporal_values(&self) -> &ArrayRef { - self.ext.storage() + self.ext.storage_array() } /// Retrieve the temporal metadata. 
diff --git a/vortex-array/src/arrays/extension/array.rs b/vortex-array/src/arrays/extension/array.rs index 8df5963a5e9..799646817a9 100644 --- a/vortex-array/src/arrays/extension/array.rs +++ b/vortex-array/src/arrays/extension/array.rs @@ -1,10 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_error::VortexExpect; +use vortex_error::VortexResult; + use crate::ArrayRef; use crate::dtype::DType; use crate::dtype::extension::ExtDTypeRef; -use crate::dtype::extension::ExtId; use crate::stats::ArrayStats; /// An extension array that wraps another array with additional type information. @@ -47,38 +49,80 @@ use crate::stats::ArrayStats; /// - Scalar access wraps storage scalars with extension metadata #[derive(Clone, Debug)] pub struct ExtensionArray { + /// The dtype of this array. This **must** be a [`DType::Extension`] variant. pub(super) dtype: DType, - pub(super) storage: ArrayRef, + + /// The backing storage array for this extension array. + pub(super) storage_array: ArrayRef, + + /// The stats for this array. pub(super) stats_set: ArrayStats, } impl ExtensionArray { - pub fn new(ext_dtype: ExtDTypeRef, storage: ArrayRef) -> Self { + /// Constructs a new `ExtensionArray`. + /// + /// # Panics + /// + /// Panics if the storage array is not compatible with the extension dtype. + pub fn new(ext_dtype: ExtDTypeRef, storage_array: ArrayRef) -> Self { + Self::try_new(ext_dtype, storage_array).vortex_expect("Failed to create `ExtensionArray`") + } + + /// Tries to construct a new `ExtensionArray`. + /// + /// # Errors + /// + /// Returns an error if the storage array is not compatible with the extension dtype. + pub fn try_new(ext_dtype: ExtDTypeRef, storage_array: ArrayRef) -> VortexResult { + // TODO(connor): Replace these statements once we add `validate_storage_array`.
+ // ext_dtype.validate_storage_array(&storage_array)?; assert_eq!( ext_dtype.storage_dtype(), - storage.dtype(), + storage_array.dtype(), "ExtensionArray: storage_dtype must match storage array DType", ); + + // SAFETY: we validate that the inputs are valid above. + Ok(unsafe { Self::new_unchecked(ext_dtype, storage_array) }) + } + + /// Creates a new `ExtensionArray`. + /// + /// # Safety + /// + /// The caller must ensure that the storage array is compatible with the extension dtype. In + /// other words, they must know that `ext_dtype.validate_storage_array(&storage_array)` has been + /// called successfully on this storage array. + pub unsafe fn new_unchecked(ext_dtype: ExtDTypeRef, storage_array: ArrayRef) -> Self { + // TODO(connor): Replace these statements once we add `validate_storage_array`. + // #[cfg(debug_assertions)] + // ext_dtype + // .validate_storage_array(&storage_array) + // .vortex_expect("[Debug Assertion]: Invalid storage array for `ExtensionArray`"); + debug_assert_eq!( + ext_dtype.storage_dtype(), + storage_array.dtype(), + "ExtensionArray: storage_dtype must match storage array DType", + ); + Self { dtype: DType::Extension(ext_dtype), - storage, + storage_array, stats_set: ArrayStats::default(), } } + /// The extension dtype of this array. 
pub fn ext_dtype(&self) -> &ExtDTypeRef { let DType::Extension(ext) = &self.dtype else { unreachable!("ExtensionArray: dtype must be an ExtDType") }; - ext - } - pub fn storage(&self) -> &ArrayRef { - &self.storage + ext } - #[inline] - pub fn id(&self) -> ExtId { - self.ext_dtype().id() + pub fn storage_array(&self) -> &ArrayRef { + &self.storage_array } } diff --git a/vortex-array/src/arrays/extension/compute/cast.rs b/vortex-array/src/arrays/extension/compute/cast.rs index 86ee5a5a573..a5f6e0d2722 100644 --- a/vortex-array/src/arrays/extension/compute/cast.rs +++ b/vortex-array/src/arrays/extension/compute/cast.rs @@ -19,7 +19,10 @@ impl CastReduce for ExtensionVTable { unreachable!("Already verified we have an extension dtype"); }; - let new_storage = match array.storage().cast(ext_dtype.storage_dtype().clone()) { + let new_storage = match array + .storage_array() + .cast(ext_dtype.storage_dtype().clone()) + { Ok(arr) => arr, Err(e) => { tracing::warn!("Failed to cast storage array: {e}"); diff --git a/vortex-array/src/arrays/extension/compute/compare.rs b/vortex-array/src/arrays/extension/compute/compare.rs index 2f358b234ab..19855329b4b 100644 --- a/vortex-array/src/arrays/extension/compute/compare.rs +++ b/vortex-array/src/arrays/extension/compute/compare.rs @@ -26,7 +26,7 @@ impl CompareKernel for ExtensionVTable { if let Some(const_ext) = rhs.as_constant() { let storage_scalar = const_ext.as_extension().to_storage_scalar(); return lhs - .storage() + .storage_array() .to_array() .binary( ConstantArray::new(storage_scalar, lhs.len()).into_array(), @@ -38,9 +38,9 @@ impl CompareKernel for ExtensionVTable { // If the RHS is an extension array matching ours, we can extract the storage. 
if let Some(rhs_ext) = rhs.as_opt::() { return lhs - .storage() + .storage_array() .to_array() - .binary(rhs_ext.storage().to_array(), Operator::from(operator)) + .binary(rhs_ext.storage_array().to_array(), Operator::from(operator)) .map(Some); } diff --git a/vortex-array/src/arrays/extension/compute/filter.rs b/vortex-array/src/arrays/extension/compute/filter.rs index bd5b668ffe9..1939e39a78e 100644 --- a/vortex-array/src/arrays/extension/compute/filter.rs +++ b/vortex-array/src/arrays/extension/compute/filter.rs @@ -15,7 +15,7 @@ impl FilterReduce for ExtensionVTable { Ok(Some( ExtensionArray::new( array.ext_dtype().clone(), - array.storage().filter(mask.clone())?, + array.storage_array().filter(mask.clone())?, ) .into_array(), )) diff --git a/vortex-array/src/arrays/extension/compute/is_constant.rs b/vortex-array/src/arrays/extension/compute/is_constant.rs index d65bb34a978..de6f7841e33 100644 --- a/vortex-array/src/arrays/extension/compute/is_constant.rs +++ b/vortex-array/src/arrays/extension/compute/is_constant.rs @@ -17,7 +17,7 @@ impl IsConstantKernel for ExtensionVTable { array: &ExtensionArray, opts: &IsConstantOpts, ) -> VortexResult> { - compute::is_constant_opts(array.storage(), opts) + compute::is_constant_opts(array.storage_array(), opts) } } diff --git a/vortex-array/src/arrays/extension/compute/is_sorted.rs b/vortex-array/src/arrays/extension/compute/is_sorted.rs index e25850551ba..eeca175773d 100644 --- a/vortex-array/src/arrays/extension/compute/is_sorted.rs +++ b/vortex-array/src/arrays/extension/compute/is_sorted.rs @@ -12,11 +12,11 @@ use crate::register_kernel; impl IsSortedKernel for ExtensionVTable { fn is_sorted(&self, array: &ExtensionArray) -> VortexResult> { - compute::is_sorted(array.storage()) + compute::is_sorted(array.storage_array()) } fn is_strict_sorted(&self, array: &ExtensionArray) -> VortexResult> { - compute::is_strict_sorted(array.storage()) + compute::is_strict_sorted(array.storage_array()) } } diff --git 
a/vortex-array/src/arrays/extension/compute/mask.rs b/vortex-array/src/arrays/extension/compute/mask.rs index 6509264aefc..b72e8ab89e3 100644 --- a/vortex-array/src/arrays/extension/compute/mask.rs +++ b/vortex-array/src/arrays/extension/compute/mask.rs @@ -15,9 +15,9 @@ use crate::scalar_fn::fns::mask::MaskReduce; impl MaskReduce for ExtensionVTable { fn mask(array: &ExtensionArray, mask: &ArrayRef) -> VortexResult> { let masked_storage = MaskExpr.try_new_array( - array.storage().len(), + array.storage_array().len(), EmptyOptions, - [array.storage().clone(), mask.clone()], + [array.storage_array().clone(), mask.clone()], )?; Ok(Some( ExtensionArray::new( diff --git a/vortex-array/src/arrays/extension/compute/min_max.rs b/vortex-array/src/arrays/extension/compute/min_max.rs index 3835e67a810..ef39f516155 100644 --- a/vortex-array/src/arrays/extension/compute/min_max.rs +++ b/vortex-array/src/arrays/extension/compute/min_max.rs @@ -17,9 +17,11 @@ impl MinMaxKernel for ExtensionVTable { fn min_max(&self, array: &ExtensionArray) -> VortexResult> { let non_nullable_ext_dtype = array.ext_dtype().with_nullability(Nullability::NonNullable); Ok( - compute::min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult { - min: Scalar::extension_ref(non_nullable_ext_dtype.clone(), min), - max: Scalar::extension_ref(non_nullable_ext_dtype, max), + compute::min_max(array.storage_array())?.map(|MinMaxResult { min, max }| { + MinMaxResult { + min: Scalar::extension_ref(non_nullable_ext_dtype.clone(), min), + max: Scalar::extension_ref(non_nullable_ext_dtype, max), + } }), ) } diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index ebdedd55537..45a3f31f1e7 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -39,7 +39,7 @@ impl ArrayParentReduceRule for ExtensionFilterPushDownRule { ) -> VortexResult> { debug_assert_eq!(child_idx, 0); 
let filtered_storage = child - .storage() + .storage_array() .clone() .filter(parent.filter_mask().clone())?; Ok(Some( @@ -95,18 +95,13 @@ mod tests { Ok(EmptyMetadata) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _extension_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _extension_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") @@ -147,7 +142,7 @@ mod tests { assert_eq!(ext_result.ext_dtype(), &ext_dtype); // Check the storage values - let storage_result: &[i64] = &ext_result.storage().to_primitive().to_buffer::(); + let storage_result: &[i64] = &ext_result.storage_array().to_primitive().to_buffer::(); assert_eq!(storage_result, &[1, 3, 5]); } @@ -173,7 +168,7 @@ mod tests { assert_eq!(ext_result.len(), 3); // Check values: should be [Some(1), None, None] - let canonical = ext_result.storage().to_primitive(); + let canonical = ext_result.storage_array().to_primitive(); assert_eq!(canonical.len(), 3); } @@ -197,18 +192,13 @@ mod tests { Ok(EmptyMetadata) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _extension_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _extension_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") diff --git a/vortex-array/src/arrays/extension/compute/slice.rs b/vortex-array/src/arrays/extension/compute/slice.rs index 5b85e50dd8b..87c6fef228e 100644 --- a/vortex-array/src/arrays/extension/compute/slice.rs +++ b/vortex-array/src/arrays/extension/compute/slice.rs @@ -14,8 +14,11 @@ use crate::arrays::slice::SliceReduce; impl SliceReduce for ExtensionVTable { fn slice(array: &Self::Array, range: Range) -> 
VortexResult> { Ok(Some( - ExtensionArray::new(array.ext_dtype().clone(), array.storage().slice(range)?) - .into_array(), + ExtensionArray::new( + array.ext_dtype().clone(), + array.storage_array().slice(range)?, + ) + .into_array(), )) } } diff --git a/vortex-array/src/arrays/extension/compute/sum.rs b/vortex-array/src/arrays/extension/compute/sum.rs index a0b8a72bc06..89798e510da 100644 --- a/vortex-array/src/arrays/extension/compute/sum.rs +++ b/vortex-array/src/arrays/extension/compute/sum.rs @@ -13,7 +13,7 @@ use crate::scalar::Scalar; impl SumKernel for ExtensionVTable { fn sum(&self, array: &ExtensionArray, accumulator: &Scalar) -> VortexResult { - compute::sum_with_accumulator(array.storage(), accumulator) + compute::sum_with_accumulator(array.storage_array(), accumulator) } } diff --git a/vortex-array/src/arrays/extension/compute/take.rs b/vortex-array/src/arrays/extension/compute/take.rs index dbf54fea41f..f223251e1b3 100644 --- a/vortex-array/src/arrays/extension/compute/take.rs +++ b/vortex-array/src/arrays/extension/compute/take.rs @@ -17,7 +17,7 @@ impl TakeExecute for ExtensionVTable { indices: &ArrayRef, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - let taken_storage = array.storage().take(indices.to_array())?; + let taken_storage = array.storage_array().take(indices.to_array())?; Ok(Some( ExtensionArray::new( array diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index e99db442573..78b1b9f32a8 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -48,7 +48,7 @@ impl VTable for ExtensionVTable { } fn len(array: &ExtensionArray) -> usize { - array.storage.len() + array.storage_array.len() } fn dtype(array: &ExtensionArray) -> &DType { @@ -65,11 +65,14 @@ impl VTable for ExtensionVTable { precision: Precision, ) { array.dtype.hash(state); - array.storage.array_hash(state, precision); + array.storage_array.array_hash(state, 
precision); } fn array_eq(array: &ExtensionArray, other: &ExtensionArray, precision: Precision) -> bool { - array.dtype == other.dtype && array.storage.array_eq(&other.storage, precision) + array.dtype == other.dtype + && array + .storage_array + .array_eq(&other.storage_array, precision) } fn nbuffers(_array: &ExtensionArray) -> usize { @@ -90,7 +93,7 @@ impl VTable for ExtensionVTable { fn child(array: &ExtensionArray, idx: usize) -> ArrayRef { match idx { - 0 => array.storage.clone(), + 0 => array.storage_array.clone(), _ => vortex_panic!("ExtensionArray child index {idx} out of bounds"), } } @@ -143,7 +146,7 @@ impl VTable for ExtensionVTable { "ExtensionArray expects exactly 1 child (storage), got {}", children.len() ); - array.storage = children + array.storage_array = children .into_iter() .next() .vortex_expect("children length already validated"); diff --git a/vortex-array/src/arrays/extension/vtable/operations.rs b/vortex-array/src/arrays/extension/vtable/operations.rs index 019e31ec77d..1b1c27660e8 100644 --- a/vortex-array/src/arrays/extension/vtable/operations.rs +++ b/vortex-array/src/arrays/extension/vtable/operations.rs @@ -13,7 +13,7 @@ impl OperationsVTable for ExtensionVTable { fn scalar_at(array: &ExtensionArray, index: usize) -> VortexResult { Ok(Scalar::extension_ref( array.ext_dtype().clone(), - array.storage().scalar_at(index)?, + array.storage_array().scalar_at(index)?, )) } } diff --git a/vortex-array/src/arrays/extension/vtable/validity.rs b/vortex-array/src/arrays/extension/vtable/validity.rs index 98269f854b3..81400c393a9 100644 --- a/vortex-array/src/arrays/extension/vtable/validity.rs +++ b/vortex-array/src/arrays/extension/vtable/validity.rs @@ -8,6 +8,6 @@ use crate::vtable::ValidityChild; impl ValidityChild for ExtensionVTable { fn validity_child(array: &ExtensionArray) -> &ArrayRef { - &array.storage + &array.storage_array } } diff --git a/vortex-array/src/arrays/filter/execute/mod.rs 
b/vortex-array/src/arrays/filter/execute/mod.rs index 0863e6f0cc7..da9b059d8ed 100644 --- a/vortex-array/src/arrays/filter/execute/mod.rs +++ b/vortex-array/src/arrays/filter/execute/mod.rs @@ -89,7 +89,7 @@ pub(super) fn execute_filter(canonical: Canonical, mask: &Arc) -> Ca Canonical::Struct(a) => Canonical::Struct(struct_::filter_struct(&a, mask)), Canonical::Extension(a) => { let filtered_storage = a - .storage() + .storage_array() .filter(values_to_mask(mask)) .vortex_expect("ExtensionArray storage type somehow could not be filtered"); Canonical::Extension(ExtensionArray::new(a.ext_dtype().clone(), filtered_storage)) diff --git a/vortex-array/src/arrays/listview/conversion.rs b/vortex-array/src/arrays/listview/conversion.rs index ca63f3521ab..b942cffcc0c 100644 --- a/vortex-array/src/arrays/listview/conversion.rs +++ b/vortex-array/src/arrays/listview/conversion.rs @@ -261,10 +261,11 @@ pub fn recursive_list_from_list_view(array: ArrayRef) -> VortexResult } } Canonical::Extension(ext_array) => { - let converted_storage = recursive_list_from_list_view(ext_array.storage().clone())?; + let converted_storage = + recursive_list_from_list_view(ext_array.storage_array().clone())?; // Avoid cloning if elements didn't change. 
- if !Arc::ptr_eq(&converted_storage, ext_array.storage()) { + if !Arc::ptr_eq(&converted_storage, ext_array.storage_array()) { ExtensionArray::new(ext_array.ext_dtype().clone(), converted_storage).into_array() } else { ext_array.into_array() diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index c6c0fc64348..a4cf0f06f7c 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -151,7 +151,7 @@ fn mask_validity_extension( ctx: &mut ExecutionCtx, ) -> VortexResult { // For extension arrays, we need to mask the underlying storage - let storage = array.storage().clone().execute::(ctx)?; + let storage = array.storage_array().clone().execute::(ctx)?; let masked_storage = mask_validity_canonical(storage, mask, ctx)?; let masked_storage = masked_storage.into_array(); Ok(ExtensionArray::new( diff --git a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs index c1dca2b3168..6913996cf16 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs @@ -10,7 +10,6 @@ use std::marker::PhantomData; use std::ops::Deref; use itertools::Itertools; -use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; @@ -273,18 +272,11 @@ impl Matcher for ExactScalarFn { fn try_match(array: &dyn DynArray) -> Option> { let scalar_fn_array = array.as_opt::()?; - let scalar_fn_vtable = scalar_fn_array - .scalar_fn - .vtable_ref::() - .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher"); - let scalar_fn_options = scalar_fn_array - .scalar_fn - .as_opt::() - .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher"); + let scalar_fn = scalar_fn_array.scalar_fn.downcast_ref::()?; Some(ScalarFnArrayView { array, - vtable: scalar_fn_vtable, - options: scalar_fn_options, + vtable: scalar_fn.vtable(), + 
options: scalar_fn.options(), }) } } diff --git a/vortex-array/src/arrow/executor/temporal.rs b/vortex-array/src/arrow/executor/temporal.rs index 6e5a6c71a90..a9316f9568e 100644 --- a/vortex-array/src/arrow/executor/temporal.rs +++ b/vortex-array/src/arrow/executor/temporal.rs @@ -147,7 +147,7 @@ where let ext_array = array.execute::(ctx)?; let primitive = ext_array - .storage() + .storage_array() .clone() .execute::(ctx)?; vortex_ensure!( diff --git a/vortex-array/src/builders/extension.rs b/vortex-array/src/builders/extension.rs index 85391c84200..4875cd39182 100644 --- a/vortex-array/src/builders/extension.rs +++ b/vortex-array/src/builders/extension.rs @@ -99,7 +99,7 @@ impl ArrayBuilder for ExtensionBuilder { unsafe fn extend_from_array_unchecked(&mut self, array: &ArrayRef) { let ext_array = array.to_extension(); - self.storage.extend_from_array(ext_array.storage()) + self.storage.extend_from_array(ext_array.storage_array()) } fn reserve_exact(&mut self, capacity: usize) { diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index 4aaee7d88f8..bff1f64fc37 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -630,7 +630,7 @@ impl Executable for CanonicalValidity { Canonical::Extension(ext) => Ok(CanonicalValidity(Canonical::Extension( ExtensionArray::new( ext.ext_dtype().clone(), - ext.storage() + ext.storage_array() .clone() .execute::(ctx)? .0 @@ -760,7 +760,7 @@ impl Executable for RecursiveCanonical { Canonical::Extension(ext) => Ok(RecursiveCanonical(Canonical::Extension( ExtensionArray::new( ext.ext_dtype().clone(), - ext.storage() + ext.storage_array() .clone() .execute::(ctx)? 
.0 diff --git a/vortex-array/src/display/mod.rs b/vortex-array/src/display/mod.rs index f47b3926585..9c34d6642b1 100644 --- a/vortex-array/src/display/mod.rs +++ b/vortex-array/src/display/mod.rs @@ -330,6 +330,7 @@ impl Display for dyn DynArray + '_ { } } +const DISPLAY_LIMIT: usize = 16; impl dyn DynArray + '_ { /// Display logical values of the array /// @@ -476,18 +477,20 @@ impl dyn DynArray + '_ { DisplayOptions::CommaSeparatedScalars { omit_comma_after_space, } => { - write!(f, "[")?; + write!(f, "{}", if f.alternate() { "[\n" } else { "[" })?; let sep = if *omit_comma_after_space { "," } else { ", " }; + let sep = if f.alternate() { ",\n" } else { sep }; + let limit = std::cmp::min(self.len(), f.precision().unwrap_or(DISPLAY_LIMIT)); write!( f, "{}", - (0..self.len()) + (0..limit) .map(|i| self .scalar_at(i) .map_or_else(|e| format!(""), |s| s.to_string())) .format(sep) )?; - write!(f, "]") + write!(f, "{}", if f.alternate() { "\n]" } else { "]" }) } DisplayOptions::TreeDisplay { buffers, diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index fd2b1814c35..3f58cfadc49 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -20,7 +20,6 @@ use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::dtype::extension::Matcher; use crate::dtype::extension::typed::DynExtDType; -use crate::dtype::extension::typed::ExtDTypeInner; use crate::scalar::ScalarValue; /// A type-erased extension dtype. @@ -40,12 +39,12 @@ pub struct ExtDTypeRef(pub(super) Arc); impl ExtDTypeRef { /// Returns the [`ExtId`] identifying this extension type. pub fn id(&self) -> ExtId { - self.0.id() + self.0.ext_id() } /// Returns the storage dtype of the extension type. pub fn storage_dtype(&self) -> &DType { - self.0.storage_dtype() + self.0.ext_storage_dtype() } /// Returns the nullability of the storage dtype. 
@@ -127,12 +126,10 @@ impl ExtDTypeRef { /// Downcast to the concrete [`ExtDType`]. /// /// Returns `Err(self)` if the downcast fails. - pub fn try_downcast(self) -> Result, ExtDTypeRef> { - if self.0.as_any().is::>() { - // SAFETY: type matches and ExtDTypeInner is the only implementor - let ptr = Arc::into_raw(self.0) as *const ExtDTypeInner; - let inner = unsafe { Arc::from_raw(ptr) }; - Ok(ExtDType(inner)) + pub fn try_downcast(self) -> Result>, ExtDTypeRef> { + if self.0.as_any().is::>() { + let ptr = Arc::into_raw(self.0) as *const ExtDType; + Ok(unsafe { Arc::from_raw(ptr) }) } else { Err(self) } @@ -143,12 +140,12 @@ impl ExtDTypeRef { /// # Panics /// /// Panics if the downcast fails. - pub fn downcast(self) -> ExtDType { + pub fn downcast(self) -> Arc> { self.try_downcast::() .map_err(|this| { vortex_err!( "Failed to downcast ExtDTypeRef {} to {}", - this.0.id(), + this.0.ext_id(), type_name::(), ) }) diff --git a/vortex-array/src/dtype/extension/matcher.rs b/vortex-array/src/dtype/extension/matcher.rs index 121f62d7aee..88e92f90233 100644 --- a/vortex-array/src/dtype/extension/matcher.rs +++ b/vortex-array/src/dtype/extension/matcher.rs @@ -1,9 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtDTypeRef; use crate::dtype::extension::ExtVTable; -use crate::dtype::extension::typed::ExtDTypeInner; /// A trait for matching extension dtypes. 
pub trait Matcher { @@ -23,13 +23,13 @@ impl Matcher for V { type Match<'a> = &'a V::Metadata; fn matches(item: &ExtDTypeRef) -> bool { - item.0.as_any().is::>() + item.0.as_any().is::>() } fn try_match<'a>(item: &'a ExtDTypeRef) -> Option> { item.0 .as_any() - .downcast_ref::>() - .map(|inner| &inner.metadata) + .downcast_ref::>() + .map(|inner| inner.metadata()) } } diff --git a/vortex-array/src/dtype/extension/mod.rs b/vortex-array/src/dtype/extension/mod.rs index 2e818412e6d..c6ecc86bbf7 100644 --- a/vortex-array/src/dtype/extension/mod.rs +++ b/vortex-array/src/dtype/extension/mod.rs @@ -34,11 +34,11 @@ pub type ExtId = arcref::ArcRef; /// Private module to seal [`typed::DynExtDType`]. mod sealed { use crate::dtype::extension::ExtVTable; - use crate::dtype::extension::typed::ExtDTypeInner; + use crate::dtype::extension::typed::ExtDType; /// Marker trait to prevent external implementations of [`super::typed::DynExtDType`]. pub(crate) trait Sealed {} /// This can be the **only** implementor for [`super::typed::DynExtDType`]. - impl Sealed for ExtDTypeInner {} + impl Sealed for ExtDType {} } diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index 4cb1aff624c..de843249727 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -32,7 +32,14 @@ use crate::scalar::ScalarValue; /// [`try_with_vtable()`]: ExtDType::try_with_vtable /// [`erased()`]: ExtDType::erased #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ExtDType(pub(super) Arc>); +pub struct ExtDType { + /// The extension dtype vtable. + vtable: V, + /// The extension dtype metadata. + metadata: V::Metadata, + /// The underlying storage dtype. + storage_dtype: DType, +} /// Convenience implementation for zero-sized VTables (or VTables that implement `Default`). 
impl ExtDType { @@ -49,60 +56,43 @@ impl ExtDType { metadata: V::Metadata, storage_dtype: DType, ) -> VortexResult { - vtable.validate_dtype(&metadata, &storage_dtype)?; - - Ok(Self(Arc::new(ExtDTypeInner:: { + let this = Self { vtable, metadata, storage_dtype, - }))) + }; + + this.vtable.validate_dtype(&this)?; + + Ok(this) } /// Returns the identifier of the extension type. pub fn id(&self) -> ExtId { - self.0.vtable.id() + self.vtable.id() } /// Returns the vtable of the extension type. pub fn vtable(&self) -> &V { - &self.0.vtable + &self.vtable } /// Returns the metadata of the extension type. pub fn metadata(&self) -> &V::Metadata { - &self.0.metadata + &self.metadata } /// Returns the storage dtype of the extension type. pub fn storage_dtype(&self) -> &DType { - &self.0.storage_dtype + &self.storage_dtype } /// Erase the concrete type information, returning a type-erased extension dtype. pub fn erased(self) -> ExtDTypeRef { - ExtDTypeRef(self.0) + ExtDTypeRef(Arc::new(self)) } } -// --------------------------------------------------------------------------- -// Private inner struct + sealed trait -// --------------------------------------------------------------------------- - -/// The private inner representation of an extension dtype, pairing a vtable with its metadata -/// and storage dtype. -/// -/// This is the sole implementor of [`DynExtDType`], enabling [`ExtDTypeRef`] to safely downcast -/// back to the concrete vtable type via [`Any`]. -#[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ExtDTypeInner { - /// The extension dtype vtable. - pub(super) vtable: V, - /// The extension dtype metadata. - pub(super) metadata: V::Metadata, - /// The underlying storage dtype. - pub(super) storage_dtype: DType, -} - /// An object-safe, sealed trait encapsulating the behavior for extension dtypes. 
/// /// This provides type-erased access to the extension dtype's identity, storage dtype, and @@ -111,9 +101,9 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { /// Returns `self` as a trait object for downcasting. fn as_any(&self) -> &dyn Any; /// Returns the [`ExtId`] identifying this extension type. - fn id(&self) -> ExtId; + fn ext_id(&self) -> ExtId; /// Returns a reference to the storage [`DType`]. - fn storage_dtype(&self) -> &DType; + fn ext_storage_dtype(&self) -> &DType; /// Returns the metadata as a trait object for downcasting. fn metadata_any(&self) -> &dyn Any; /// Formats the metadata using [`Debug`]. @@ -135,16 +125,16 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { -> fmt::Result; } -impl DynExtDType for ExtDTypeInner { +impl DynExtDType for ExtDType { fn as_any(&self) -> &dyn Any { self } - fn id(&self) -> ExtId { + fn ext_id(&self) -> ExtId { self.vtable.id() } - fn storage_dtype(&self) -> &DType { + fn ext_storage_dtype(&self) -> &DType { &self.storage_dtype } @@ -186,8 +176,7 @@ impl DynExtDType for ExtDTypeInner { } fn value_validate(&self, storage_value: &ScalarValue) -> VortexResult<()> { - self.vtable - .validate_scalar_value(&self.metadata, &self.storage_dtype, storage_value) + self.vtable.validate_scalar_value(self, storage_value) } fn value_display( @@ -195,10 +184,7 @@ impl DynExtDType for ExtDTypeInner { f: &mut fmt::Formatter<'_>, storage_value: &ScalarValue, ) -> fmt::Result { - match self - .vtable - .unpack_native(&self.metadata, &self.storage_dtype, storage_value) - { + match self.vtable.unpack_native(self, storage_value) { Ok(native) => fmt::Display::fmt(&native, f), Err(_) => write!( f, diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 84f530ea06f..8d324521b41 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -7,7 +7,7 @@ use std::hash::Hash; use 
vortex_error::VortexResult; -use crate::dtype::DType; +use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::scalar::ScalarValue; @@ -36,7 +36,7 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult; /// Validate that the given storage type is compatible with this extension type. - fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()>; + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()>; // Methods related to the extension scalar values. @@ -49,27 +49,24 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Returns an error if the storage [`ScalarValue`] is not compatible with the extension type. fn validate_scalar_value( &self, - metadata: &Self::Metadata, - storage_dtype: &DType, + ext_dtype: &ExtDType, storage_value: &ScalarValue, ) -> VortexResult<()> { - self.unpack_native(metadata, storage_dtype, storage_value) - .map(|_| ()) + self.unpack_native(ext_dtype, storage_value).map(|_| ()) } /// Validate and unpack a native value from the storage [`ScalarValue`]. /// /// Note that [`ExtVTable::validate_dtype()`] is always called first to validate the storage - /// [`DType`], and the [`Scalar`](crate::scalar::Scalar) implementation will verify that the - /// storage value is compatible with the storage dtype on construction. + /// [`crate::dtype::DType`], and the [`Scalar`](crate::scalar::Scalar) implementation will + /// verify that the storage value is compatible with the storage dtype on construction. /// /// # Errors /// /// Returns an error if the storage [`ScalarValue`] is not compatible with the extension type. 
fn unpack_native<'a>( &self, - metadata: &'a Self::Metadata, - storage_dtype: &'a DType, + ext_dtype: &'a ExtDType, storage_value: &'a ScalarValue, ) -> VortexResult>; } diff --git a/vortex-array/src/extension/datetime/date.rs b/vortex-array/src/extension/datetime/date.rs index 802bd8cf6a2..6d67051721b 100644 --- a/vortex-array/src/extension/datetime/date.rs +++ b/vortex-array/src/extension/datetime/date.rs @@ -91,12 +91,13 @@ impl ExtVTable for Date { TimeUnit::try_from(tag) } - fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { + let metadata = ext_dtype.metadata(); let ptype = date_ptype(metadata) .ok_or_else(|| vortex_err!("Date type does not support time unit {}", metadata))?; vortex_ensure!( - storage_dtype.as_ptype() == ptype, + ext_dtype.storage_dtype().as_ptype() == ptype, "Date storage dtype for {} must be {}", metadata, ptype @@ -107,10 +108,10 @@ impl ExtVTable for Date { fn unpack_native( &self, - metadata: &Self::Metadata, - _storage_dtype: &DType, + ext_dtype: &ExtDType, storage_value: &ScalarValue, ) -> VortexResult> { + let metadata = ext_dtype.metadata(); match metadata { TimeUnit::Milliseconds => Ok(DateValue::Milliseconds( storage_value.as_primitive().cast::()?, diff --git a/vortex-array/src/extension/datetime/time.rs b/vortex-array/src/extension/datetime/time.rs index 77982f5ee8e..59e7c98bc7e 100644 --- a/vortex-array/src/extension/datetime/time.rs +++ b/vortex-array/src/extension/datetime/time.rs @@ -92,12 +92,13 @@ impl ExtVTable for Time { TimeUnit::try_from(tag) } - fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { + let metadata = ext_dtype.metadata(); let ptype = time_ptype(metadata) .ok_or_else(|| vortex_err!("Time type does not support time unit {}", metadata))?; vortex_ensure!( - storage_dtype.as_ptype() 
== ptype, + ext_dtype.storage_dtype().as_ptype() == ptype, "Time storage dtype for {} must be {}", metadata, ptype @@ -108,13 +109,12 @@ impl ExtVTable for Time { fn unpack_native( &self, - metadata: &Self::Metadata, - _storage_dtype: &DType, + ext_dtype: &ExtDType, storage_value: &ScalarValue, ) -> VortexResult> { let length_of_time = storage_value.as_primitive().cast::()?; - let (span, value) = match *metadata { + let (span, value) = match *ext_dtype.metadata() { TimeUnit::Seconds => { let v = i32::try_from(length_of_time) .map_err(|e| vortex_err!("Time seconds value out of i32 range: {e}"))?; diff --git a/vortex-array/src/extension/datetime/timestamp.rs b/vortex-array/src/extension/datetime/timestamp.rs index d65f6b24cce..d314cef60a3 100644 --- a/vortex-array/src/extension/datetime/timestamp.rs +++ b/vortex-array/src/extension/datetime/timestamp.rs @@ -169,13 +169,9 @@ impl ExtVTable for Timestamp { }) } - fn validate_dtype( - &self, - _metadata: &Self::Metadata, - storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { vortex_ensure!( - matches!(storage_dtype, DType::Primitive(PType::I64, _)), + matches!(ext_dtype.storage_dtype(), DType::Primitive(PType::I64, _)), "Timestamp storage dtype must be i64" ); Ok(()) @@ -183,10 +179,10 @@ impl ExtVTable for Timestamp { fn unpack_native<'a>( &self, - metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + ext_dtype: &'a ExtDType, storage_value: &'a ScalarValue, ) -> VortexResult> { + let metadata = ext_dtype.metadata(); let ts_value = storage_value.as_primitive().cast::()?; let tz = metadata.tz.as_ref(); diff --git a/vortex-array/src/extension/mod.rs b/vortex-array/src/extension/mod.rs index 5c29154ded0..9f81e7fb310 100644 --- a/vortex-array/src/extension/mod.rs +++ b/vortex-array/src/extension/mod.rs @@ -6,6 +6,7 @@ use std::fmt; pub mod datetime; +pub mod uuid; #[cfg(test)] mod tests; diff --git a/vortex-array/src/extension/tests/divisible_int.rs 
b/vortex-array/src/extension/tests/divisible_int.rs index c34fcea3df6..7b6e501b7cb 100644 --- a/vortex-array/src/extension/tests/divisible_int.rs +++ b/vortex-array/src/extension/tests/divisible_int.rs @@ -16,7 +16,7 @@ use crate::dtype::extension::ExtVTable; use crate::scalar::ScalarValue; /// The divisor stored as extension metadata. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct Divisor(pub u64); impl fmt::Display for Divisor { @@ -53,23 +53,22 @@ impl ExtVTable for DivisibleInt { fn validate_dtype( &self, - _metadata: &Self::Metadata, - storage_dtype: &DType, + ext_dtype: &crate::dtype::extension::ExtDType, ) -> VortexResult<()> { vortex_ensure!( - matches!(storage_dtype, DType::Primitive(PType::U64, _)), + matches!(ext_dtype.storage_dtype(), DType::Primitive(PType::U64, _)), "divisible int storage dtype must be u64" ); Ok(()) } - fn unpack_native( + fn unpack_native<'a>( &self, - metadata: &Self::Metadata, - _storage_dtype: &DType, - storage_value: &ScalarValue, - ) -> VortexResult> { + ext_dtype: &'a crate::dtype::extension::ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { let value = storage_value.as_primitive().cast::()?; + let metadata = ext_dtype.metadata(); if value % metadata.0 != 0 { vortex_bail!("{} is not divisible by {}", value, metadata.0); } @@ -86,6 +85,7 @@ mod tests { use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtVTable; #[test] @@ -109,29 +109,25 @@ mod tests { #[test] fn rejects_wrong_storage_dtype() { - let vtable = DivisibleInt; let divisor = Divisor(10); assert!( - vtable - .validate_dtype( - &divisor, - &DType::Primitive(PType::I32, Nullability::NonNullable) - ) - .is_err() + ExtDType::::try_new( + divisor, + DType::Primitive(PType::I32, Nullability::NonNullable) + ) + .is_err() ); assert!( - vtable - .validate_dtype(&divisor, 
&DType::Utf8(Nullability::NonNullable)) + ExtDType::::try_new(divisor, DType::Utf8(Nullability::NonNullable)) .is_err() ); assert!( - vtable - .validate_dtype( - &divisor, - &DType::Primitive(PType::U64, Nullability::NonNullable) - ) - .is_ok() + ExtDType::::try_new( + divisor, + DType::Primitive(PType::U64, Nullability::NonNullable) + ) + .is_ok() ); } } diff --git a/vortex-array/src/extension/uuid/metadata.rs b/vortex-array/src/extension/uuid/metadata.rs new file mode 100644 index 00000000000..7e7dd8b16d8 --- /dev/null +++ b/vortex-array/src/extension/uuid/metadata.rs @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; +use std::hash::Hash; +use std::hash::Hasher; + +use uuid::Version; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +/// Converts a `u8` discriminant back to a [`uuid::Version`]. +pub(crate) fn u8_to_version(b: u8) -> VortexResult { + match b { + 0 => Ok(Version::Nil), + 1 => Ok(Version::Mac), + 2 => Ok(Version::Dce), + 3 => Ok(Version::Md5), + 4 => Ok(Version::Random), + 5 => Ok(Version::Sha1), + 6 => Ok(Version::SortMac), + 7 => Ok(Version::SortRand), + 8 => Ok(Version::Custom), + 0xff => Ok(Version::Max), + _ => vortex_bail!("unknown UUID version discriminant: {b}"), + } +} + +/// Metadata for the UUID extension type. +/// +/// Optionally records which UUID version the column contains (e.g. v4 random, v7 +/// sort-random). When `None`, the column may contain any mix of versions. +#[derive(Clone, Debug, Default)] +pub struct UuidMetadata { + /// The UUID version, if known. + pub version: Option, +} + +impl fmt::Display for UuidMetadata { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.version { + None => write!(f, ""), + Some(v) => write!(f, "v{}", v as u8), + } + } +} + +// `uuid::Version` derives `PartialEq` but not `Eq` or `Hash`, so we implement these +// manually using the `#[repr(u8)]` discriminant. 
+ +impl PartialEq for UuidMetadata { + fn eq(&self, other: &Self) -> bool { + self.version.map(|v| v as u8) == other.version.map(|v| v as u8) + } +} + +impl Eq for UuidMetadata {} + +impl Hash for UuidMetadata { + fn hash(&self, state: &mut H) { + self.version.map(|v| v as u8).hash(state); + } +} diff --git a/vortex-array/src/extension/uuid/mod.rs b/vortex-array/src/extension/uuid/mod.rs new file mode 100644 index 00000000000..e4347c2513c --- /dev/null +++ b/vortex-array/src/extension/uuid/mod.rs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! UUID extension type for Vortex. +//! +//! Provides a UUID extension type backed by `FixedSizeList(Primitive(U8), 16)` storage. Each UUID +//! is stored as 16 bytes in big-endian (network) byte order, matching [RFC 4122] and Arrow's +//! [canonical UUID extension]. +//! +//! [RFC 4122]: https://www.rfc-editor.org/rfc/rfc4122 +//! [canonical UUID extension]: https://arrow.apache.org/docs/format/CanonicalExtensions.html#uuid + +mod metadata; +pub use metadata::UuidMetadata; + +pub(crate) mod vtable; + +/// The VTable for the UUID extension type. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct Uuid; diff --git a/vortex-array/src/extension/uuid/vtable.rs b/vortex-array/src/extension/uuid/vtable.rs new file mode 100644 index 00000000000..962b83f8907 --- /dev/null +++ b/vortex-array/src/extension/uuid/vtable.rs @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use uuid; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; + +use crate::dtype::DType; +use crate::dtype::PType; +use crate::dtype::extension::ExtDType; +use crate::dtype::extension::ExtId; +use crate::dtype::extension::ExtVTable; +use crate::extension::uuid::Uuid; +use crate::extension::uuid::UuidMetadata; +use crate::extension::uuid::metadata::u8_to_version; +use crate::scalar::PValue; +use crate::scalar::ScalarValue; + +/// The number of bytes in a UUID. +pub(crate) const UUID_BYTE_LEN: usize = 16; + +impl ExtVTable for Uuid { + type Metadata = UuidMetadata; + type NativeValue<'a> = uuid::Uuid; + + fn id(&self) -> ExtId { + ExtId::new_ref("vortex.uuid") + } + + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + match metadata.version { + None => Ok(Vec::new()), + Some(v) => Ok(vec![v as u8]), + } + } + + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + let version = match metadata.len() { + 0 => None, + 1 => Some(u8_to_version(metadata[0])?), + other => vortex_bail!("UUID metadata must be 0 or 1 bytes, got {other}"), + }; + + Ok(UuidMetadata { version }) + } + + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { + let storage_dtype = ext_dtype.storage_dtype(); + let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else { + vortex_bail!("UUID storage dtype must be a FixedSizeList, got {storage_dtype}"); + }; + + vortex_ensure_eq!( + *list_size as usize, + 
UUID_BYTE_LEN, + "UUID storage FixedSizeList must have size {UUID_BYTE_LEN}, got {list_size}" + ); + + let DType::Primitive(ptype, elem_nullability) = element_dtype.as_ref() else { + vortex_bail!("UUID element dtype must be Primitive(U8), got {element_dtype}"); + }; + + vortex_ensure_eq!( + *ptype, + PType::U8, + "UUID element dtype must be U8, got {ptype}" + ); + vortex_ensure!( + !elem_nullability.is_nullable(), + "UUID element dtype must be non-nullable" + ); + + Ok(()) + } + + fn unpack_native<'a>( + &self, + ext_dtype: &ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + let elements = storage_value.as_list(); + vortex_ensure_eq!( + elements.len(), + UUID_BYTE_LEN, + "UUID scalar must have exactly {UUID_BYTE_LEN} bytes, got {}", + elements.len() + ); + + let mut bytes = [0u8; UUID_BYTE_LEN]; + for (i, elem) in elements.iter().enumerate() { + let Some(scalar_value) = elem else { + vortex_bail!("UUID byte at index {i} must not be null"); + }; + let PValue::U8(b) = scalar_value.as_primitive() else { + vortex_bail!("UUID byte at index {i} must be U8"); + }; + bytes[i] = *b; + } + + let parsed = uuid::Uuid::from_bytes(bytes); + + // Verify the parsed UUID matches the expected version, if one is set. + if let Some(expected) = ext_dtype.metadata().version { + let expected = expected as u8; + let actual = parsed + .get_version() + .ok_or_else(|| vortex_err!("UUID has unrecognized version nibble"))? 
+ as u8; + + vortex_ensure_eq!( + expected, + actual, + "UUID version mismatch: expected v{expected}, got v{actual}", + ); + } + + Ok(parsed) + } +} + +#[expect( + clippy::cast_possible_truncation, + reason = "UUID_BYTE_LEN always fits both usize and u32" +)] +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use uuid::Version; + use vortex_error::VortexResult; + + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::extension::ExtDType; + use crate::dtype::extension::ExtVTable; + use crate::extension::uuid::Uuid; + use crate::extension::uuid::UuidMetadata; + use crate::extension::uuid::vtable::UUID_BYTE_LEN; + use crate::scalar::Scalar; + use crate::scalar::ScalarValue; + + #[rstest] + #[case::no_version(None)] + #[case::v4_random(Some(Version::Random))] + #[case::v7_sort_rand(Some(Version::SortRand))] + #[case::nil(Some(Version::Nil))] + #[case::max(Some(Version::Max))] + fn roundtrip_metadata(#[case] version: Option) -> VortexResult<()> { + let metadata = UuidMetadata { version }; + let bytes = Uuid.serialize_metadata(&metadata)?; + let expected_len = if version.is_none() { 0 } else { 1 }; + assert_eq!(bytes.len(), expected_len); + let deserialized = Uuid.deserialize_metadata(&bytes)?; + assert_eq!(deserialized, metadata); + Ok(()) + } + + #[test] + fn metadata_display_no_version() { + let metadata = UuidMetadata { version: None }; + assert_eq!(metadata.to_string(), ""); + } + + #[test] + fn metadata_display_with_version() { + let metadata = UuidMetadata { + version: Some(Version::Random), + }; + assert_eq!(metadata.to_string(), "v4"); + + let metadata = UuidMetadata { + version: Some(Version::SortRand), + }; + assert_eq!(metadata.to_string(), "v7"); + } + + #[rstest] + #[case::non_nullable(Nullability::NonNullable)] + #[case::nullable(Nullability::Nullable)] + fn validate_correct_storage_dtype(#[case] nullability: Nullability) -> VortexResult<()> { + let metadata = 
UuidMetadata::default(); + let storage_dtype = uuid_storage_dtype(nullability); + ExtDType::try_with_vtable(Uuid, metadata, storage_dtype)?; + Ok(()) + } + + #[test] + fn validate_rejects_wrong_list_size() { + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + 8, + Nullability::NonNullable, + ); + assert!(ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), storage_dtype).is_err()); + } + + #[test] + fn validate_rejects_wrong_element_type() { + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)), + UUID_BYTE_LEN as u32, + Nullability::NonNullable, + ); + assert!(ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), storage_dtype).is_err()); + } + + #[test] + fn validate_rejects_nullable_elements() { + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::Nullable)), + UUID_BYTE_LEN as u32, + Nullability::NonNullable, + ); + assert!(ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), storage_dtype).is_err()); + } + + #[test] + fn validate_rejects_non_fsl() { + let storage_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + assert!(ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), storage_dtype).is_err()); + } + + #[test] + fn unpack_native_uuid() -> VortexResult<()> { + let expected = uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000") + .map_err(|e| vortex_error::vortex_err!("{e}"))?; + + let ext_dtype = ExtDType::try_new( + UuidMetadata::default(), + uuid_storage_dtype(Nullability::NonNullable), + )?; + let children: Vec = expected + .as_bytes() + .iter() + .map(|&b| Scalar::primitive(b, Nullability::NonNullable)) + .collect(); + let storage_scalar = Scalar::fixed_size_list( + DType::Primitive(PType::U8, Nullability::NonNullable), + children, + Nullability::NonNullable, + ); + + let storage_value = storage_scalar + .value() + .ok_or_else(|| 
vortex_error::vortex_err!("expected non-null scalar"))?; + let result = Uuid.unpack_native(&ext_dtype, storage_value)?; + assert_eq!(result, expected); + assert_eq!(result.to_string(), "550e8400-e29b-41d4-a716-446655440000"); + Ok(()) + } + + #[test] + fn unpack_native_rejects_version_mismatch() -> VortexResult<()> { + // This is a v4 UUID. + let v4_uuid = uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000") + .map_err(|e| vortex_error::vortex_err!("{e}"))?; + assert_eq!(v4_uuid.get_version(), Some(Version::Random)); + + // Metadata says v7, but the UUID is v4. + let ext_dtype = ExtDType::try_with_vtable( + Uuid, + UuidMetadata { + version: Some(Version::SortRand), + }, + uuid_storage_dtype(Nullability::NonNullable), + )?; + let children: Vec = v4_uuid + .as_bytes() + .iter() + .map(|&b| Scalar::primitive(b, Nullability::NonNullable)) + .collect(); + let storage_scalar = Scalar::fixed_size_list( + DType::Primitive(PType::U8, Nullability::NonNullable), + children, + Nullability::NonNullable, + ); + + let storage_value = storage_scalar + .value() + .ok_or_else(|| vortex_error::vortex_err!("expected non-null scalar"))?; + assert!(Uuid.unpack_native(&ext_dtype, storage_value).is_err()); + Ok(()) + } + + /// Builds a [`ScalarValue`] for a UUID's 16 bytes, suitable for passing to `unpack_native`. + fn uuid_storage_scalar(uuid: &uuid::Uuid) -> ScalarValue { + let children: Vec = uuid + .as_bytes() + .iter() + .map(|&b| Scalar::primitive(b, Nullability::NonNullable)) + .collect(); + let scalar = Scalar::fixed_size_list( + DType::Primitive(PType::U8, Nullability::NonNullable), + children, + Nullability::NonNullable, + ); + scalar.value().unwrap().clone() + } + + #[test] + fn unpack_native_accepts_matching_version() -> VortexResult<()> { + // This is a v4 UUID. 
+ let v4_uuid = uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000") + .map_err(|e| vortex_error::vortex_err!("{e}"))?; + + let ext_dtype = ExtDType::try_new( + UuidMetadata { + version: Some(Version::Random), + }, + uuid_storage_dtype(Nullability::NonNullable), + ) + .unwrap(); + let storage_value = uuid_storage_scalar(&v4_uuid); + + let result = Uuid.unpack_native(&ext_dtype, &storage_value)?; + assert_eq!(result, v4_uuid); + Ok(()) + } + + #[test] + fn unpack_native_any_version_accepts_all() -> VortexResult<()> { + // A v4 UUID should be accepted when metadata has no version constraint. + let v4_uuid = uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000") + .map_err(|e| vortex_error::vortex_err!("{e}"))?; + + let ext_dtype = ExtDType::try_new( + UuidMetadata::default(), + uuid_storage_dtype(Nullability::NonNullable), + ) + .unwrap(); + let storage_value = uuid_storage_scalar(&v4_uuid); + + let result = Uuid.unpack_native(&ext_dtype, &storage_value)?; + assert_eq!(result, v4_uuid); + Ok(()) + } + + fn uuid_storage_dtype(nullability: Nullability) -> DType { + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + UUID_BYTE_LEN as u32, + nullability, + ) + } +} diff --git a/vortex-array/src/scalar/arrow.rs b/vortex-array/src/scalar/arrow.rs index bce11e27e0f..45672e745fa 100644 --- a/vortex-array/src/scalar/arrow.rs +++ b/vortex-array/src/scalar/arrow.rs @@ -466,16 +466,14 @@ mod tests { fn validate_dtype( &self, - _options: &Self::Metadata, - _storage_dtype: &DType, + _ext_dtype: &crate::dtype::extension::ExtDType, ) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a crate::dtype::extension::ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") diff --git a/vortex-array/src/scalar/tests/casting.rs b/vortex-array/src/scalar/tests/casting.rs index 72b44a5ed55..e5677f7c1f9 100644 --- 
a/vortex-array/src/scalar/tests/casting.rs +++ b/vortex-array/src/scalar/tests/casting.rs @@ -43,18 +43,13 @@ mod tests { Ok(0) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") @@ -264,18 +259,13 @@ mod tests { vortex_bail!("not implemented") } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") @@ -327,18 +317,13 @@ mod tests { vortex_bail!("not implemented") } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") diff --git a/vortex-array/src/scalar/typed_view/extension/tests.rs b/vortex-array/src/scalar/typed_view/extension/tests.rs index 52a4127e12c..4801e8e6e0c 100644 --- a/vortex-array/src/scalar/typed_view/extension/tests.rs +++ b/vortex-array/src/scalar/typed_view/extension/tests.rs @@ -32,18 +32,13 @@ impl ExtVTable for TestI32Ext { Ok(EmptyMetadata) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - 
_storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") @@ -121,18 +116,13 @@ fn test_ext_scalar_partial_ord_different_types() { Ok(EmptyMetadata) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") @@ -309,18 +299,13 @@ fn test_ext_scalar_with_metadata() { vortex_bail!("not implemented") } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 3d5eea1b500..a6c4a5a463f 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -3,6 +3,7 @@ //! Type-erased scalar function ([`ScalarFnRef`]). +use std::any::type_name; use std::fmt::Debug; use std::fmt::Display; use std::fmt::Formatter; @@ -12,6 +13,7 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_err; use vortex_utils::debug_with::DebugWith; use crate::ArrayRef; @@ -34,7 +36,7 @@ use crate::scalar_fn::fns::not::Not; use crate::scalar_fn::options::ScalarFnOptions; use crate::scalar_fn::signature::ScalarFnSignature; use crate::scalar_fn::typed::DynScalarFn; -use crate::scalar_fn::typed::ScalarFnInner; +use crate::scalar_fn::typed::ScalarFn; /// A type-erased scalar function, pairing a vtable with bound options behind a trait object. 
/// @@ -54,22 +56,15 @@ impl ScalarFnRef { /// Returns whether the scalar function is of the given vtable type. pub fn is(&self) -> bool { - self.0.as_any().is::>() + self.0.as_any().is::>() } /// Returns the typed options for this scalar function if it matches the given vtable type. pub fn as_opt(&self) -> Option<&V::Options> { - self.downcast_inner::().map(|inner| &inner.options) - } - - /// Returns a reference to the typed vtable if it matches the given vtable type. - pub fn vtable_ref(&self) -> Option<&V> { - self.downcast_inner::().map(|inner| &inner.vtable) - } - - /// Downcast the inner to the concrete `ScalarFnInner`. - fn downcast_inner(&self) -> Option<&ScalarFnInner> { - self.0.as_any().downcast_ref::>() + self.0 + .as_any() + .downcast_ref::>() + .map(|sf| sf.options()) } /// Returns the typed options for this scalar function if it matches the given vtable type. @@ -82,6 +77,40 @@ impl ScalarFnRef { .vortex_expect("Expression options type mismatch") } + /// Downcast to the concrete [`ScalarFn`]. + /// + /// Returns `Err(self)` if the downcast fails. + pub fn try_downcast(self) -> Result>, ScalarFnRef> { + if self.0.as_any().is::>() { + let ptr = Arc::into_raw(self.0) as *const ScalarFn; + Ok(unsafe { Arc::from_raw(ptr) }) + } else { + Err(self) + } + } + + /// Downcast to the concrete [`ScalarFn`]. + /// + /// # Panics + /// + /// Panics if the downcast fails. + pub fn downcast(self) -> Arc> { + self.try_downcast::() + .map_err(|this| { + vortex_err!( + "Failed to downcast ScalarFnRef {} to {}", + this.0.id(), + type_name::(), + ) + }) + .vortex_expect("Failed to downcast ScalarFnRef") + } + + /// Try to downcast into a typed [`ScalarFn`]. + pub fn downcast_ref(&self) -> Option<&ScalarFn> { + self.0.as_any().downcast_ref::>() + } + /// The type-erased options for this scalar function. 
pub fn options(&self) -> ScalarFnOptions<'_> { ScalarFnOptions { inner: &*self.0 } diff --git a/vortex-array/src/scalar_fn/mod.rs b/vortex-array/src/scalar_fn/mod.rs index 017c71bfe99..945c37262de 100644 --- a/vortex-array/src/scalar_fn/mod.rs +++ b/vortex-array/src/scalar_fn/mod.rs @@ -36,11 +36,11 @@ pub type ScalarFnId = ArcRef; /// Private module to seal [`typed::DynScalarFn`]. mod sealed { use crate::scalar_fn::ScalarFnVTable; - use crate::scalar_fn::typed::ScalarFnInner; + use crate::scalar_fn::typed::ScalarFn; /// Marker trait to prevent external implementations of [`super::typed::DynScalarFn`]. pub(crate) trait Sealed {} /// This can be the **only** implementor for [`super::typed::DynScalarFn`]. - impl Sealed for ScalarFnInner {} + impl Sealed for ScalarFn {} } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index 2d6f684ed61..ce3a3e377a4 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -4,7 +4,7 @@ //! Typed and inner representations of scalar functions. //! //! - [`ScalarFn`]: The public typed wrapper, parameterized by a concrete [`ScalarFnVTable`]. -//! - [`ScalarFnInner`]: The private inner struct that holds the vtable + options. +//! - [`ScalarFn`]: The private inner struct that holds the vtable + options. //! - [`DynScalarFn`]: The private sealed trait for type-erased dispatch (bound, options in self). use std::any::Any; @@ -35,9 +35,43 @@ use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::SimplifyCtx; +/// A typed scalar function instance, parameterized by a concrete [`ScalarFnVTable`]. +/// +/// You can construct one via [`new()`], and erase the type with [`erased()`] to obtain a +/// [`ScalarFnRef`]. +/// +/// [`new()`]: ScalarFn::new +/// [`erased()`]: ScalarFn::erased +pub struct ScalarFn { + vtable: V, + options: V::Options, +} + +impl ScalarFn { + /// Create a new typed scalar function instance. 
+ pub fn new(vtable: V, options: V::Options) -> Self { + Self { vtable, options } + } + + /// Returns a reference to the vtable. + pub fn vtable(&self) -> &V { + &self.vtable + } + + /// Returns a reference to the options. + pub fn options(&self) -> &V::Options { + &self.options + } + + /// Erase the concrete type information, returning a type-erased [`ScalarFnRef`]. + pub fn erased(self) -> ScalarFnRef { + ScalarFnRef(Arc::new(self)) + } +} + /// An object-safe, sealed trait for bound scalar function dispatch. /// -/// Options are stored inside the implementing [`ScalarFnInner`], not passed externally. +/// Options are stored inside the implementing [`ScalarFn`], not passed externally. /// This is the sole trait behind [`ScalarFnRef`]'s `Arc`. pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { fn as_any(&self) -> &dyn Any; @@ -86,16 +120,7 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { fn options_debug(&self, f: &mut Formatter<'_>) -> fmt::Result; } -/// The private inner representation of a bound scalar function, pairing a vtable with its options. -/// -/// This is the sole implementor of [`DynScalarFn`], enabling [`ScalarFnRef`] to safely downcast -/// back to the concrete vtable type via [`Any`]. -pub(super) struct ScalarFnInner { - pub(super) vtable: V, - pub(super) options: V::Options, -} - -impl DynScalarFn for ScalarFnInner { +impl DynScalarFn for ScalarFn { #[inline(always)] fn as_any(&self) -> &dyn Any { self @@ -232,34 +257,3 @@ impl DynScalarFn for ScalarFnInner { Debug::fmt(&self.options, f) } } - -/// A typed scalar function instance, parameterized by a concrete [`ScalarFnVTable`]. -/// -/// You can construct one via [`new()`], and erase the type with [`erased()`] to obtain a -/// [`ScalarFnRef`]. -/// -/// [`new()`]: ScalarFn::new -/// [`erased()`]: ScalarFn::erased -pub struct ScalarFn(pub(super) Arc>); - -impl ScalarFn { - /// Create a new typed scalar function instance. 
- pub fn new(vtable: V, options: V::Options) -> Self { - Self(Arc::new(ScalarFnInner { vtable, options })) - } - - /// Returns a reference to the vtable. - pub fn vtable(&self) -> &V { - &self.0.vtable - } - - /// Returns a reference to the options. - pub fn options(&self) -> &V::Options { - &self.0.options - } - - /// Erase the concrete type information, returning a type-erased [`ScalarFnRef`]. - pub fn erased(self) -> ScalarFnRef { - ScalarFnRef(self.0) - } -} diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 8986b172350..2b88a142144 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -36,7 +36,7 @@ use crate::scalar_fn::fns::binary::Binary; use crate::scalar_fn::fns::operators::Operator; /// Validity information for an array -#[derive(Clone, Debug)] +#[derive(Clone)] pub enum Validity { /// Items *can't* be null NonNullable, @@ -50,6 +50,17 @@ pub enum Validity { Array(ArrayRef), } +impl Debug for Validity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NonNullable => write!(f, "NonNullable"), + Self::AllValid => write!(f, "AllValid"), + Self::AllInvalid => write!(f, "AllInvalid"), + Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()), + } + } +} + impl Validity { /// Make a step towards canonicalising validity if necessary pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult { diff --git a/vortex-array/src/vtable/mod.rs b/vortex-array/src/vtable/mod.rs index a038a83639a..013821fb719 100644 --- a/vortex-array/src/vtable/mod.rs +++ b/vortex-array/src/vtable/mod.rs @@ -185,8 +185,8 @@ pub trait VTable: 'static + Sized + Send + Sync + Debug { /// do next. 
/// /// Instead of recursively executing children, implementations should return - /// [`ExecutionStep::ExecuteChild(i)`] to request that the scheduler execute a child first, - /// or [`ExecutionStep::Done(result)`] when the + /// [`ExecutionStep::ExecuteChild`] to request that the scheduler execute a child first, + /// or [`ExecutionStep::Done`] when the /// encoding can produce a result directly. /// /// Array execution is designed such that repeated execution of an array will eventually diff --git a/vortex-btrblocks/benches/compress.rs b/vortex-btrblocks/benches/compress.rs index f0aa6c7101d..1e0c9f2e1c2 100644 --- a/vortex-btrblocks/benches/compress.rs +++ b/vortex-btrblocks/benches/compress.rs @@ -21,11 +21,11 @@ mod benchmarks { fn make_clickbench_window_name() -> ArrayRef { // A test that's meant to mirror the WindowName column from ClickBench. - let mut values = buffer_mut![-1i32; 1_000_000]; + let mut values = buffer_mut![-1i32; 65_536]; let mut visited = HashSet::new(); let mut rng = StdRng::seed_from_u64(1u64); while visited.len() < 223 { - let random = (rng.next_u32() as usize) % 1_000_000; + let random = (rng.next_u32() as usize) % 65_536; if visited.contains(&random) { continue; } diff --git a/vortex-btrblocks/src/canonical_compressor.rs b/vortex-btrblocks/src/canonical_compressor.rs index 33c3f3b1df9..db850f0bee9 100644 --- a/vortex-btrblocks/src/canonical_compressor.rs +++ b/vortex-btrblocks/src/canonical_compressor.rs @@ -299,7 +299,7 @@ impl CanonicalCompressor for BtrBlocksCompressor { } // Compress the underlying storage array. 
- let compressed_storage = self.compress(ext_array.storage())?; + let compressed_storage = self.compress(ext_array.storage_array())?; Ok( ExtensionArray::new(ext_array.ext_dtype().clone(), compressed_storage) diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock index b695c9cf14d..17c1bff5ea9 100644 --- a/vortex-buffer/public-api.lock +++ b/vortex-buffer/public-api.lock @@ -336,6 +336,10 @@ impl core::fmt::Debug for vortex_buffer::BitBuffer pub fn vortex_buffer::BitBuffer::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +impl core::fmt::Display for vortex_buffer::BitBuffer + +pub fn vortex_buffer::BitBuffer::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + impl core::iter::traits::collect::FromIterator for vortex_buffer::BitBuffer pub fn vortex_buffer::BitBuffer::from_iter>(iter: T) -> Self @@ -448,6 +452,10 @@ pub fn vortex_buffer::BitBufferMut::empty() -> Self pub fn vortex_buffer::BitBufferMut::false_count(&self) -> usize +pub fn vortex_buffer::BitBufferMut::fill_range(&mut self, start: usize, end: usize, value: bool) + +pub unsafe fn vortex_buffer::BitBufferMut::fill_range_unchecked(&mut self, start: usize, end: usize, value: bool) + pub fn vortex_buffer::BitBufferMut::freeze(self) -> vortex_buffer::BitBuffer pub fn vortex_buffer::BitBufferMut::from_buffer(buffer: vortex_buffer::ByteBufferMut, offset: usize, len: usize) -> Self @@ -704,7 +712,9 @@ pub type vortex_buffer::Buffer::Target = [T] pub fn vortex_buffer::Buffer::deref(&self) -> &Self::Target -pub struct vortex_buffer::BufferIterator +pub struct vortex_buffer::BufferIterator + +impl core::iter::traits::exact_size::ExactSizeIterator for vortex_buffer::BufferIterator impl core::iter::traits::iterator::Iterator for vortex_buffer::BufferIterator diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index f0539c839ca..87b4054774d 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -1,6 +1,9 @@ // 
SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::Result as FmtResult; use std::ops::BitAnd; use std::ops::BitOr; use std::ops::BitXor; @@ -22,6 +25,7 @@ use crate::bit::get_bit_unchecked; use crate::bit::ops::bitwise_binary_op; use crate::bit::ops::bitwise_unary_op; use crate::buffer; +use crate::trusted_len::TrustedLenExt; /// An immutable bitset stored as a packed byte buffer. #[derive(Debug, Clone, Eq)] @@ -35,6 +39,18 @@ pub struct BitBuffer { len: usize, } +const LIMIT_LEN: usize = 16; +impl Display for BitBuffer { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + let limit = f.precision().unwrap_or(LIMIT_LEN); + let buf: Vec = self.into_iter().take(limit).collect(); + f.debug_struct("BitBuffer") + .field("len", &self.len) + .field("buffer", &buf) + .finish() + } +} + impl PartialEq for BitBuffer { fn eq(&self, other: &Self) -> bool { if self.len != other.len { @@ -85,7 +101,11 @@ impl BitBuffer { // Slice the buffer to ensure the offset is within the first byte let byte_offset = offset / 8; let offset = offset % 8; - let buffer = buffer.slice(byte_offset..); + let buffer = if byte_offset != 0 { + buffer.slice(byte_offset..) 
+ } else { + buffer + }; Self { buffer, @@ -327,7 +347,12 @@ impl BitBuffer { self.len, ); } - bitwise_unary_op(self, |a| a) + // Use Chunk iterator here to reset offset to 0 + let iter = self.chunks().iter_padded(); + let iter = unsafe { iter.trusted_len() }; + let result = Buffer::::from_trusted_len_iter(iter).into_byte_buffer(); + + BitBuffer::new(result, self.len()) } } diff --git a/vortex-buffer/src/bit/buf_mut.rs b/vortex-buffer/src/bit/buf_mut.rs index 5d16c783af9..da3c1bf3d5c 100644 --- a/vortex-buffer/src/bit/buf_mut.rs +++ b/vortex-buffer/src/bit/buf_mut.rs @@ -16,6 +16,61 @@ use crate::bit::set_bit_unchecked; use crate::bit::unset_bit_unchecked; use crate::buffer_mut; +/// Sets all bits in the bit-range `[start_bit, end_bit)` of `slice` to `value`. +#[inline(always)] +fn fill_bits(slice: &mut [u8], start_bit: usize, end_bit: usize, value: bool) { + if start_bit >= end_bit { + return; + } + + let fill_byte: u8 = if value { 0xFF } else { 0x00 }; + + let start_byte = start_bit / 8; + let start_rem = start_bit % 8; + let end_byte = end_bit / 8; + let end_rem = end_bit % 8; + + if start_byte == end_byte { + // All bits are in the same byte + let mask = ((1u8 << (end_rem - start_rem)) - 1) << start_rem; + if value { + slice[start_byte] |= mask; + } else { + slice[start_byte] &= !mask; + } + } else { + // First partial byte + if start_rem != 0 { + let mask = !((1u8 << start_rem) - 1); + if value { + slice[start_byte] |= mask; + } else { + slice[start_byte] &= !mask; + } + } + + // Middle bytes + let fill_start = if start_rem != 0 { + start_byte + 1 + } else { + start_byte + }; + if fill_start < end_byte { + slice[fill_start..end_byte].fill(fill_byte); + } + + // Last partial byte + if end_rem != 0 { + let mask = (1u8 << end_rem) - 1; + if value { + slice[end_byte] |= mask; + } else { + slice[end_byte] &= !mask; + } + } + } +} + /// A mutable bitset buffer that allows random access to individual bits for set and get. 
/// /// @@ -398,13 +453,13 @@ impl BitBufferMut { /// the length will be incremented by `n`. /// /// Panics if the buffer does not have `n` slots left. + #[inline] pub fn append_n(&mut self, value: bool, n: usize) { if n == 0 { return; } - let start_bit_pos = self.offset + self.len; - let end_bit_pos = start_bit_pos + n; + let end_bit_pos = self.offset + self.len + n; let required_bytes = end_bit_pos.div_ceil(8); // Ensure buffer has enough bytes @@ -412,58 +467,38 @@ impl BitBufferMut { self.buffer.push_n(0x00, required_bytes - self.buffer.len()); } - let fill_byte = if value { 0xFF } else { 0x00 }; - - // Calculate byte positions - let start_byte = start_bit_pos / 8; - let start_bit = start_bit_pos % 8; - let end_byte = end_bit_pos / 8; - let end_bit = end_bit_pos % 8; - - let slice = self.buffer.as_mut_slice(); - - if start_byte == end_byte { - // All bits are in the same byte - let mask = ((1u8 << (end_bit - start_bit)) - 1) << start_bit; - if value { - slice[start_byte] |= mask; - } else { - slice[start_byte] &= !mask; - } - } else { - // Fill the first partial byte - if start_bit != 0 { - let mask = !((1u8 << start_bit) - 1); - if value { - slice[start_byte] |= mask; - } else { - slice[start_byte] &= !mask; - } - } + let start = self.len; + self.len += n; + self.fill_range(start, self.len, value); + } - // Fill the complete middle bytes - let fill_start = if start_bit != 0 { - start_byte + 1 - } else { - start_byte - }; - let fill_end = end_byte; - if fill_start < fill_end { - slice[fill_start..fill_end].fill(fill_byte); - } + /// Sets all bits in the range `[start, end)` to `value`. + /// + /// This operates on an arbitrary range within the existing length of the buffer. + /// Panics if `end > self.len` or `start > end`. 
+ #[inline(always)] + pub fn fill_range(&mut self, start: usize, end: usize, value: bool) { + assert!(end <= self.len, "end {end} exceeds len {}", self.len); + assert!(start <= end, "start {start} exceeds end {end}"); - // Fill the last partial byte - if end_bit != 0 { - let mask = (1u8 << end_bit) - 1; - if value { - slice[end_byte] |= mask; - } else { - slice[end_byte] &= !mask; - } - } - } + // SAFETY: assertions above guarantee start <= end <= self.len, + // so offset + end fits within the buffer. + unsafe { self.fill_range_unchecked(start, end, value) } + } - self.len += n; + /// Sets all bits in the range `[start, end)` to `value` without bounds checking. + /// + /// # Safety + /// + /// The caller must ensure that `start <= end <= self.len`. + #[inline(always)] + pub unsafe fn fill_range_unchecked(&mut self, start: usize, end: usize, value: bool) { + fill_bits( + self.buffer.as_mut_slice(), + self.offset + start, + self.offset + end, + value, + ); } /// Append a [`BitBuffer`] to this [`BitBufferMut`] @@ -612,6 +647,7 @@ impl Default for BitBufferMut { impl Not for BitBufferMut { type Output = BitBufferMut; + #[inline] fn not(mut self) -> Self::Output { ops::bitwise_unary_op_mut(&mut self, |b| !b); self diff --git a/vortex-buffer/src/bit/ops.rs b/vortex-buffer/src/bit/ops.rs index d44224cc1d3..7f07a39515a 100644 --- a/vortex-buffer/src/bit/ops.rs +++ b/vortex-buffer/src/bit/ops.rs @@ -6,15 +6,17 @@ use crate::BitBufferMut; use crate::Buffer; use crate::trusted_len::TrustedLenExt; +#[inline] pub(super) fn bitwise_unary_op u64>(buffer: &BitBuffer, op: F) -> BitBuffer { - let iter = buffer.chunks().iter_padded().map(op); + let iter = buffer.unaligned_chunks().iter().map(op); let iter = unsafe { iter.trusted_len() }; let result = Buffer::::from_trusted_len_iter(iter).into_byte_buffer(); - BitBuffer::new(result, buffer.len()) + BitBuffer::new_with_offset(result, buffer.len(), buffer.offset()) } +#[inline] pub(super) fn bitwise_unary_op_mut u64>(buffer: &mut 
BitBufferMut, mut op: F) { let slice_mut = buffer.as_mut_slice(); @@ -54,6 +56,25 @@ pub(super) fn bitwise_binary_op u64>( ) -> BitBuffer { assert_eq!(left.len(), right.len()); + // If the buffers are aligned, we can use the fast path. + if left.offset().is_multiple_of(8) && right.offset().is_multiple_of(8) { + let left_chunks = left.unaligned_chunks(); + let right_chunks = right.unaligned_chunks(); + if left_chunks.lead_padding() == 0 + && left_chunks.trailing_padding() == 0 + && right_chunks.lead_padding() == 0 + && right_chunks.trailing_padding() == 0 + { + let iter = left_chunks + .iter() + .zip(right_chunks.iter()) + .map(|(l, r)| op(l, r)); + let iter = unsafe { iter.trusted_len() }; + let result = Buffer::::from_trusted_len_iter(iter).into_byte_buffer(); + return BitBuffer::new(result, left.len()); + } + } + let iter = left .chunks() .iter_padded() diff --git a/vortex-buffer/src/buffer.rs b/vortex-buffer/src/buffer.rs index f941db6dbec..5a7de66f4fd 100644 --- a/vortex-buffer/src/buffer.rs +++ b/vortex-buffer/src/buffer.rs @@ -645,9 +645,11 @@ impl Buf for ByteBuffer { } /// Owned iterator over a [`Buffer`]. -pub struct BufferIterator { - buffer: Buffer, - index: usize, +pub struct BufferIterator { + // Keep the buffer alive for the duration of the iteration. + _buffer: Buffer, + ptr: *const T, + end: *const T, } impl Iterator for BufferIterator { @@ -655,29 +657,37 @@ impl Iterator for BufferIterator { #[inline] fn next(&mut self) -> Option { - (self.index < self.buffer.len()).then(move || { - let value = self.buffer[self.index]; - self.index += 1; - value - }) + if self.ptr == self.end { + None + } else { + // SAFETY: ptr is within the buffer and has not reached end. 
+ let value = unsafe { self.ptr.read() }; + self.ptr = unsafe { self.ptr.add(1) }; + Some(value) + } } #[inline] fn size_hint(&self) -> (usize, Option) { - let remaining = self.buffer.len() - self.index; + let remaining = unsafe { self.end.offset_from(self.ptr) } as usize; (remaining, Some(remaining)) } } +impl ExactSizeIterator for BufferIterator {} + impl IntoIterator for Buffer { type Item = T; type IntoIter = BufferIterator; #[inline] fn into_iter(self) -> Self::IntoIter { + let ptr = self.as_slice().as_ptr(); + let end = unsafe { ptr.add(self.len()) }; BufferIterator { - buffer: self, - index: 0, + _buffer: self, + ptr, + end, } } } diff --git a/vortex-cuda/build.rs b/vortex-cuda/build.rs index d4b41eae6c7..88c1af961ed 100644 --- a/vortex-cuda/build.rs +++ b/vortex-cuda/build.rs @@ -138,7 +138,7 @@ fn nvcc_compile_ptx( .join(cu_path.file_name().unwrap()) .with_extension("ptx"); - cmd.arg("-std=c++17") + cmd.arg("-std=c++20") .arg("-arch=native") // Flags forwarded to Clang. .arg("--compiler-options=-Wall -Wextra -Wpedantic -Werror") diff --git a/vortex-cuda/cub/build.rs b/vortex-cuda/cub/build.rs index e9c98a1b214..16b758a57dd 100644 --- a/vortex-cuda/cub/build.rs +++ b/vortex-cuda/cub/build.rs @@ -57,7 +57,7 @@ fn is_cuda_available() -> bool { fn compile_shared_library(kernel_dir: &Path, sources: &[PathBuf], out_dir: &Path) { let lib_path = out_dir.join("libvortex_cub.so"); let mut cmd = Command::new("nvcc"); - cmd.args(["-std=c++17", "-arch=native"]); + cmd.args(["-std=c++20", "-arch=native"]); if env::var("PROFILE").unwrap() == "debug" { cmd.args(["-O0", "-g", "-G", "-lineinfo"]); diff --git a/vortex-cuda/kernels/src/dynamic_dispatch.cu b/vortex-cuda/kernels/src/dynamic_dispatch.cu index 9168a8580ab..ef9c824e187 100644 --- a/vortex-cuda/kernels/src/dynamic_dispatch.cu +++ b/vortex-cuda/kernels/src/dynamic_dispatch.cu @@ -35,7 +35,7 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val /// compressed or raw data from global 
memory and writes decoded elements into /// the stage's shared memory region. /// -/// @param input Global memory pointer to the stage's encoded input data +/// @param input Global memory pointer to the stage's encoded input data /// @param smem_output Shared memory pointer where decoded elements are written /// @param chunk_start Starting index of the chunk to process (block-relative for output stage) /// @param chunk_len Number of elements to produce (may be < ELEMENTS_PER_BLOCK for tail blocks) @@ -44,7 +44,7 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val /// to resolve offsets to ends/values decoded by earlier stages template __device__ inline void dynamic_source_op(const T *__restrict input, - T *__restrict smem_output, + T *__restrict &smem_output, uint64_t chunk_start, uint32_t chunk_len, const struct SourceOp &source_op, @@ -57,7 +57,10 @@ __device__ inline void dynamic_source_op(const T *__restrict input, constexpr uint32_t LANES_PER_FL_BLOCK = FL_CHUNK_SIZE / T_BITS; const uint32_t bit_width = source_op.params.bitunpack.bit_width; const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width; - const uint64_t first_fl_block = chunk_start / FL_CHUNK_SIZE; + + const uint32_t element_offset = source_op.params.bitunpack.element_offset; + const uint32_t smem_within_offset = (chunk_start + element_offset) % FL_CHUNK_SIZE; + const uint64_t first_fl_block = (chunk_start + element_offset) / FL_CHUNK_SIZE; // FL blocks must divide evenly. Otherwise, the last unpack would overflow smem. 
static_assert((ELEMENTS_PER_BLOCK % FL_CHUNK_SIZE) == 0); @@ -65,7 +68,7 @@ __device__ inline void dynamic_source_op(const T *__restrict input, const auto div_ceil = [](auto a, auto b) { return (a + b - 1) / b; }; - const uint32_t num_fl_chunks = div_ceil(chunk_len, FL_CHUNK_SIZE); + const uint32_t num_fl_chunks = div_ceil(chunk_len + smem_within_offset, FL_CHUNK_SIZE); for (uint32_t chunk_idx = 0; chunk_idx < num_fl_chunks; ++chunk_idx) { const T *packed_chunk = input + (first_fl_block + chunk_idx) * packed_words_per_fl_block; @@ -75,7 +78,8 @@ __device__ inline void dynamic_source_op(const T *__restrict input, bit_unpack_lane(packed_chunk, smem_lane, 0, lane, bit_width); } } - break; + smem_output += smem_within_offset; + return; } case SourceOp::LOAD: { @@ -83,7 +87,7 @@ __device__ inline void dynamic_source_op(const T *__restrict input, for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) { smem_output[i] = input[chunk_start + i]; } - break; + return; } case SourceOp::RUNEND: { @@ -107,7 +111,7 @@ __device__ inline void dynamic_source_op(const T *__restrict input, smem_output[i] = values[min(current_run, num_runs - 1)]; } - break; + return; } default: @@ -273,6 +277,18 @@ __device__ void execute_stage(const struct Stage &stage, __syncthreads(); } +/// Computes the number of elements to process in an output tile. +/// +/// Each tile decodes exactly one FL block == SMEM_TILE_SIZE elements into +/// shared memory. In case BITUNPACK is sliced, we need to account for the +/// sub-byte element offset. +__device__ inline uint32_t output_tile_len(const struct Stage &stage, uint32_t block_len, uint32_t tile_off) { + const uint32_t element_offset = (tile_off == 0 && stage.source.op_code == SourceOp::BITUNPACK) + ? stage.source.params.bitunpack.element_offset + : 0; + return min(SMEM_TILE_SIZE - element_offset, block_len - tile_off); +} + /// Entry point of the dynamic dispatch kernel. 
/// /// Executes the plan's stages in order: @@ -285,9 +301,9 @@ __device__ void execute_stage(const struct Stage &stage, /// @param array_len Total number of elements to produce /// @param plan Device pointer to the dispatch plan template -__device__ void dynamic_dispatch_impl(T *__restrict output, - uint64_t array_len, - const struct DynamicDispatchPlan *__restrict plan) { +__device__ void dynamic_dispatch(T *__restrict output, + uint64_t array_len, + const struct DynamicDispatchPlan *__restrict plan) { // Dynamically-sized shared memory: The host computes the exact byte count // needed to hold all stage outputs that must coexist simultaneously, and @@ -310,21 +326,20 @@ __device__ void dynamic_dispatch_impl(T *__restrict output, execute_stage(stage, smem_base, 0, stage.len, smem_output, 0); } - // Output stage: process in SMEM_TILE_SIZE tiles to reduce smem footprint. - // Each tile decodes into the same smem region and writes to global memory. const struct Stage &output_stage = smem_plan.stages[last]; const uint64_t block_start = static_cast(blockIdx.x) * ELEMENTS_PER_BLOCK; const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len); const uint32_t block_len = static_cast(block_end - block_start); - for (uint32_t tile_off = 0; tile_off < block_len; tile_off += SMEM_TILE_SIZE) { - const uint32_t tile_len = min(SMEM_TILE_SIZE, block_len - tile_off); + for (uint32_t tile_off = 0; tile_off < block_len;) { + const uint32_t tile_len = output_tile_len(output_stage, block_len, tile_off); execute_stage(output_stage, smem_base, block_start + tile_off, tile_len, output, block_start + tile_off); + tile_off += tile_len; } } @@ -334,7 +349,7 @@ __device__ void dynamic_dispatch_impl(T *__restrict output, Type *__restrict output, \ uint64_t array_len, \ const struct DynamicDispatchPlan *__restrict plan) { \ - dynamic_dispatch_impl(output, array_len, plan); \ + dynamic_dispatch(output, array_len, plan); \ } 
FOR_EACH_UNSIGNED_INT(GENERATE_DYNAMIC_DISPATCH_KERNEL) diff --git a/vortex-cuda/kernels/src/dynamic_dispatch.h b/vortex-cuda/kernels/src/dynamic_dispatch.h index f8fbeaf6c13..9f7dc122f1b 100644 --- a/vortex-cuda/kernels/src/dynamic_dispatch.h +++ b/vortex-cuda/kernels/src/dynamic_dispatch.h @@ -44,11 +44,13 @@ union SourceParams { /// Unpack bit-packed data using FastLanes layout. struct BitunpackParams { uint8_t bit_width; + uint32_t element_offset; // Sub-byte offset } bitunpack; /// Copy elements verbatim from global memory to shared memory. + /// The input pointer is pre-adjusted on the host to account for slicing. struct LoadParams { - uint8_t _padding; + uint8_t _placeholder; } load; /// Decode run-end encoding using ends and values already in shared memory. diff --git a/vortex-cuda/src/arrow/canonical.rs b/vortex-cuda/src/arrow/canonical.rs index 0cfc223f8cd..65115cbf2f3 100644 --- a/vortex-cuda/src/arrow/canonical.rs +++ b/vortex-cuda/src/arrow/canonical.rs @@ -116,7 +116,7 @@ fn export_canonical( vortex_bail!("only support temporal extension types currently"); } - let values = extension.storage().to_primitive(); + let values = extension.storage_array().to_primitive(); let len = extension.len(); let PrimitiveArrayParts { diff --git a/vortex-cuda/src/canonical.rs b/vortex-cuda/src/canonical.rs index 7ae499d9577..d4eb37e09f3 100644 --- a/vortex-cuda/src/canonical.rs +++ b/vortex-cuda/src/canonical.rs @@ -133,7 +133,7 @@ impl CanonicalCudaExt for Canonical { Canonical::Extension(ext) => { // Copy the storage array to host and rewrap in ExtensionArray. let host_storage = ext - .storage() + .storage_array() .to_canonical()? .into_host() .await? 
diff --git a/vortex-cuda/src/device_buffer.rs b/vortex-cuda/src/device_buffer.rs index 17bcd44f5d4..c8d2841f10a 100644 --- a/vortex-cuda/src/device_buffer.rs +++ b/vortex-cuda/src/device_buffer.rs @@ -81,8 +81,6 @@ mod private { } } -// Get it back out as a View of u8 - impl CudaDeviceBuffer { /// Creates a new CUDA device buffer from a [`CudaSlice`]. /// @@ -101,6 +99,16 @@ impl CudaDeviceBuffer { } } + /// Returns the byte offset within the allocated buffer. + pub fn offset(&self) -> usize { + self.offset + } + + /// Returns the adjusted device pointer accounting for the offset. + pub fn offset_ptr(&self) -> sys::CUdeviceptr { + self.device_ptr + self.offset as u64 + } + /// Returns a [`CudaView`] to the CUDA device buffer. pub fn as_view(&self) -> CudaView<'_, T> { // Return a new &[T] @@ -159,7 +167,7 @@ impl CudaBufferExt for BufferHandle { .as_any() .downcast_ref::() .ok_or_else(|| vortex_err!("expected CudaDeviceBuffer"))? - .device_ptr; + .offset_ptr(); Ok(ptr) } @@ -281,7 +289,7 @@ impl DeviceBuffer for CudaDeviceBuffer { /// Slices the CUDA device buffer to a subrange. /// - /// **IMPORTANT**: this is a byte range, not elements range, due to the DeviceBuffer interface. + /// This is a byte range, not elements range, due to the DeviceBuffer interface. fn slice(&self, range: Range) -> Arc { assert!( range.end <= self.len, diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index a2a431f1490..dd4e78c47af 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -30,11 +30,19 @@ unsafe impl cudarc::driver::DeviceRepr for Stage {} impl SourceOp { /// Unpack bit-packed data using FastLanes layout. - pub fn bitunpack(bit_width: u8) -> Self { + /// + /// `element_offset` (0..1023) is the sub-block position within the first + /// FastLanes block. 
The device pointer already accounts for buffer slicing, + /// but sub-block alignment cannot be expressed as pointer arithmetic on + /// bit-packed data, so it is passed as a kernel parameter. + pub fn bitunpack(bit_width: u8, element_offset: u16) -> Self { Self { op_code: SourceOp_SourceOpCode_BITUNPACK, params: SourceParams { - bitunpack: SourceParams_BitunpackParams { bit_width }, + bitunpack: SourceParams_BitunpackParams { + bit_width, + element_offset: u32::from(element_offset), + }, }, } } @@ -134,9 +142,8 @@ impl Stage { } } - /// Create the output stage. Uses [`SMEM_TILE_SIZE`] as the shared memory - /// region size — the kernel tiles `ELEMENTS_PER_BLOCK` elements through - /// this smaller region to reduce shared memory usage. + /// Create the output stage. The kernel tiles `ELEMENTS_PER_BLOCK` elements + /// through a [`SMEM_TILE_SIZE`] shared-memory region to reduce usage. pub fn output( input_ptr: u64, smem_offset: u32, @@ -192,6 +199,7 @@ mod tests { use cudarc::driver::DevicePtr; use cudarc::driver::LaunchConfig; use cudarc::driver::PushKernelArg; + use rstest::rstest; use vortex::array::IntoArray; use vortex::array::ToCanonical; use vortex::array::arrays::DictArray; @@ -259,7 +267,7 @@ mod tests { let plan = DynamicDispatchPlan::new([Stage::output( input_ptr, 0, - SourceOp::bitunpack(bit_width), + SourceOp::bitunpack(bit_width, 0), &scalar_ops, )]); assert_eq!(plan.stages[0].num_scalar_ops, 4); @@ -279,13 +287,13 @@ mod tests { 0xAAAA, 0, 256, - SourceOp::bitunpack(4), + SourceOp::bitunpack(4, 0), &[ScalarOp::frame_of_ref(10)], ), Stage::output( 0xBBBB, 256, - SourceOp::bitunpack(6), + SourceOp::bitunpack(6, 0), &[ScalarOp::frame_of_ref(42), ScalarOp::dict(0)], ), ]); @@ -684,4 +692,316 @@ mod tests { Ok(()) } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + 
#[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_primitive( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let len = 5000; + let data: Vec = (0..len).map(|i| (i * 7) % 1000).collect(); + + let prim = PrimitiveArray::new(Buffer::from(data.clone()), NonNullable); + + let sliced = prim.into_array().slice(slice_start..slice_end)?; + + let expected: Vec = data[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_zigzag_bitpacked( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let bit_width = 10u8; + let max_val = (1u32 << bit_width) - 1; + let len = 5000; + + let raw: Vec = (0..len).map(|i| (i as u32) % max_val).collect(); + let all_decoded: Vec = raw + .iter() + .map(|&v| (v >> 1) ^ (0u32.wrapping_sub(v & 1))) + .collect(); + + let prim = PrimitiveArray::new(Buffer::from(raw), NonNullable); + let bp = BitPackedArray::encode(&prim.into_array(), bit_width)?; + let zz = ZigZagArray::try_new(bp.into_array())?; + + let sliced = zz.into_array().slice(slice_start..slice_end)?; + let expected: Vec = all_decoded[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, 
expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_dict_with_primitive_codes( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let dict_values: Vec = vec![100, 200, 300, 400, 500]; + let dict_size = dict_values.len(); + let len = 5000; + let codes: Vec = (0..len).map(|i| (i % dict_size) as u32).collect(); + + let codes_prim = PrimitiveArray::new(Buffer::from(codes.clone()), NonNullable); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); + let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; + + let sliced = dict.into_array().slice(slice_start..slice_end)?; + + let expected: Vec = codes[slice_start..slice_end] + .iter() + .map(|&c| dict_values[c as usize]) + .collect(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_bitpacked( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let bit_width = 10u8; + let max_val = (1u32 << bit_width) - 1; + let len = 5000; + + let data: Vec = (0..len).map(|i| (i as u32) % max_val).collect(); + let prim = 
PrimitiveArray::new(Buffer::from(data.clone()), NonNullable); + let bp = BitPackedArray::encode(&prim.into_array(), bit_width)?; + + let sliced = bp.into_array().slice(slice_start..slice_end)?; + let expected: Vec = data[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_for_bitpacked( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let reference = 100u32; + let bit_width = 10u8; + let max_val = (1u32 << bit_width) - 1; + let len = 5000; + + let encoded_data: Vec = (0..len).map(|i| (i as u32) % max_val).collect(); + let prim = PrimitiveArray::new(Buffer::from(encoded_data.clone()), NonNullable); + let bp = BitPackedArray::encode(&prim.into_array(), bit_width)?; + let for_arr = FoRArray::try_new(bp.into_array(), Scalar::from(reference))?; + + let all_decoded: Vec = encoded_data.iter().map(|&v| v + reference).collect(); + + let sliced = for_arr.into_array().slice(slice_start..slice_end)?; + let expected: Vec = all_decoded[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(400, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 
4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_runend( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let ends: Vec = vec![500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]; + let values: Vec = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100]; + let len = 5000; + + let all_decoded: Vec = (0..len) + .map(|i| { + let run = ends.iter().position(|&e| (i as u32) < e).unwrap(); + values[run] + }) + .collect(); + + let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); + let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); + let re = RunEndArray::new(ends_arr, values_arr); + + let sliced = re.into_array().slice(slice_start..slice_end)?; + let expected: Vec = all_decoded[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } + + #[rstest] + #[case(0, 1024)] + #[case(0, 3000)] + #[case(0, 4096)] + #[case(500, 600)] + #[case(500, 1024)] + #[case(500, 2048)] + #[case(500, 4500)] + #[case(777, 3333)] + #[case(1024, 2048)] + #[case(1024, 4096)] + #[case(1500, 3500)] + #[case(2048, 4096)] + #[case(2500, 4500)] + #[case(3333, 4444)] + #[crate::test] + fn test_sliced_dict_for_bp_values_bp_codes( + #[case] slice_start: usize, + #[case] slice_end: usize, + ) -> VortexResult<()> { + let dict_reference = 1_000_000u32; + let dict_residuals: Vec = (0..64).collect(); + let dict_expected: Vec = dict_residuals.iter().map(|&r| r + dict_reference).collect(); + let dict_size = dict_residuals.len(); + + let len = 5000; + let codes: Vec = (0..len).map(|i| (i % dict_size) as u32).collect(); + let all_decoded: Vec 
= codes.iter().map(|&c| dict_expected[c as usize]).collect(); + + // BitPack+FoR the dict values + let dict_prim = PrimitiveArray::new(Buffer::from(dict_residuals), NonNullable); + let dict_bp = BitPackedArray::encode(&dict_prim.into_array(), 6)?; + let dict_for = FoRArray::try_new(dict_bp.into_array(), Scalar::from(dict_reference))?; + + // BitPack the codes + let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); + let codes_bp = BitPackedArray::encode(&codes_prim.into_array(), 6)?; + + let dict = DictArray::try_new(codes_bp.into_array(), dict_for.into_array())?; + + let sliced = dict.into_array().slice(slice_start..slice_end)?; + let expected: Vec = all_decoded[slice_start..slice_end].to_vec(); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let (plan, _bufs) = build_plan(&sliced, &cuda_ctx)?; + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?; + assert_eq!(actual, expected); + + Ok(()) + } } diff --git a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs index 734e46b35fd..679fdcc6a0f 100644 --- a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs +++ b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs @@ -10,15 +10,19 @@ use futures::executor::block_on; use vortex::array::ArrayRef; use vortex::array::DynArray; +use vortex::array::ExecutionCtx; use vortex::array::arrays::DictVTable; use vortex::array::arrays::PrimitiveVTable; +use vortex::array::arrays::SliceVTable; use vortex::array::arrays::primitive::PrimitiveArrayParts; use vortex::array::buffer::BufferHandle; +use vortex::array::session::ArraySession; use vortex::dtype::PType; use vortex::encodings::alp::ALPFloat; use vortex::encodings::alp::ALPVTable; use vortex::encodings::fastlanes::BitPackedArrayParts; use vortex::encodings::fastlanes::BitPackedVTable; +use vortex::encodings::fastlanes::FoRArray; use vortex::encodings::fastlanes::FoRVTable; use 
vortex::encodings::runend::RunEndArrayParts; use vortex::encodings::runend::RunEndVTable; @@ -26,6 +30,7 @@ use vortex::encodings::zigzag::ZigZagVTable; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::error::vortex_err; +use vortex::session::VortexSession; use super::DynamicDispatchPlan; use super::MAX_SCALAR_OPS; @@ -77,6 +82,7 @@ struct Pipeline { /// - `ALPArray` → recurse + `ALP` scalar op (f32 only, no patches) /// - `DictArray` → input stage for values + recurse codes + `DICT` scalar op /// - `RunEndArray` → input stages for ends/values + `RUNEND` source +/// - `SliceArray` → resolve via child's slice reduce/kernel /// /// # Limitations /// @@ -84,8 +90,6 @@ struct Pipeline { /// receive a value regardless of whether the input was null. Only arrays with /// `NonNullable` or `AllValid` validity produce correct results. /// -/// **Slicing**: Not supported. -/// /// **Patches**: `BitPackedArray` with patches and `ALPArray` with patches are /// not supported and will return an error. /// @@ -152,6 +156,8 @@ impl PlanBuilderState<'_> { self.walk_runend(array) } else if id == PrimitiveVTable::ID { self.walk_primitive(array) + } else if id == SliceVTable::ID { + self.walk_slice(array) } else { vortex_bail!( "Encoding {:?} not supported by dynamic dispatch plan builder", @@ -160,7 +166,34 @@ impl PlanBuilderState<'_> { } } + /// SliceArray → resolve the slice via reduce/execute rules. + /// + /// When the plan builder encounters a `SliceArray`, it resolves the slice + /// by invoking the child's `reduce_parent`, `execute_parent`. + fn walk_slice(&mut self, array: ArrayRef) -> VortexResult { + let slice_arr = array.as_::(); + let child = slice_arr.child().clone(); + + // reduce_parent: (for types with SliceReduceAdaptor, like FoR/ZigZag) + if let Some(reduced) = child.vtable().reduce_parent(&child, &array, 0)? 
{ + return self.walk(reduced); + } + + // execute_parent: (for types with SliceExecuteAdaptor/SliceKernel, like BitPacked) + let mut ctx = ExecutionCtx::new(VortexSession::empty().with::()); + if let Some(executed) = child.vtable().execute_parent(&child, &array, 0, &mut ctx)? { + return self.walk(executed); + } + + vortex_bail!( + "Cannot resolve SliceArray wrapping {:?} in dynamic dispatch plan builder", + child.encoding_id() + ) + } + /// Canonical primitive array → LOAD source op. + /// + /// The device pointer accounts for buffer slicing, so no offset parameter is needed. fn walk_primitive(&mut self, array: ArrayRef) -> VortexResult { let prim = array.to_canonical()?.into_primitive(); let PrimitiveArrayParts { buffer, .. } = prim.into_parts(); @@ -170,11 +203,14 @@ impl PlanBuilderState<'_> { Ok(Pipeline { source: SourceOp::load(), scalar_ops: vec![], - input_ptr: ptr, + input_ptr: ptr as u64, }) } /// BitPackedArray → BITUNPACK source op. + /// + /// The sub-byte element offset (0..=1023) is passed as a kernel parameter + /// as it cannot be expressed as pointer arithmetic on the device pointer. fn walk_bitpacked(&mut self, array: ArrayRef) -> VortexResult { let bp = array .try_into::() @@ -187,11 +223,6 @@ impl PlanBuilderState<'_> { .. } = bp.into_parts(); - if offset != 0 { - vortex_bail!( - "Dynamic dispatch does not support sliced BitPackedArray (offset={offset})" - ); - } if patches.is_some() { vortex_bail!("Dynamic dispatch does not support BitPackedArray with patches"); } @@ -200,9 +231,9 @@ impl PlanBuilderState<'_> { let ptr = device_buf.cuda_device_ptr()?; self.device_buffers.push(device_buf); Ok(Pipeline { - source: SourceOp::bitunpack(bit_width), + source: SourceOp::bitunpack(bit_width, offset), scalar_ops: vec![], - input_ptr: ptr, + input_ptr: ptr as u64, }) } @@ -313,7 +344,7 @@ impl PlanBuilderState<'_> { } /// Extract a FoR reference scalar as u64 bits. 
-fn extract_for_reference(for_arr: &vortex::encodings::fastlanes::FoRArray) -> VortexResult { +fn extract_for_reference(for_arr: &FoRArray) -> VortexResult { if let Ok(v) = u32::try_from(for_arr.reference_scalar()) { Ok(v as u64) } else if let Ok(v) = i32::try_from(for_arr.reference_scalar()) { diff --git a/vortex-cxx/CMakeLists.txt b/vortex-cxx/CMakeLists.txt index 557b8acdbd7..9d5a4eb02d2 100644 --- a/vortex-cxx/CMakeLists.txt +++ b/vortex-cxx/CMakeLists.txt @@ -8,7 +8,7 @@ cmake_policy(SET CMP0135 NEW) project(vortex) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_program(SCCACHE_PROGRAM sccache) @@ -72,17 +72,17 @@ set(CPP_PRIVATE_INCLUDE_DIRS # Create the main library combining C++ and Rust code add_library(vortex STATIC ${CPP_SOURCE_FILE}) -target_include_directories(vortex PUBLIC ${CPP_INCLUDE_DIRS} +target_include_directories(vortex PUBLIC ${CPP_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/corrosion_generated/cxxbridge/vortex_cxx_bridge/include) -target_include_directories(vortex PRIVATE +target_include_directories(vortex PRIVATE ${CPP_PRIVATE_INCLUDE_DIRS} ) -target_link_libraries(vortex +target_link_libraries(vortex PUBLIC nanoarrow_static PRIVATE vortex_cxx_bridge ) -if (VORTEX_ENABLE_ASAN) +if (VORTEX_ENABLE_ASAN) target_compile_options(vortex PRIVATE -fsanitize=leak,address,undefined -fno-omit-frame-pointer -fno-common -O1) target_link_options(vortex PRIVATE -fsanitize=leak,address,undefined) endif() @@ -102,7 +102,7 @@ if (VORTEX_ENABLE_TESTING) target_include_directories(vortex_cxx_test PUBLIC ${CPP_INCLUDE_DIRS}) target_include_directories(vortex_cxx_test PRIVATE cpp/tests) target_link_libraries(vortex_cxx_test PRIVATE gtest_main vortex nanoarrow_static) - target_include_directories(vortex_cxx_test PRIVATE + target_include_directories(vortex_cxx_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/corrosion_generated/cxxbridge/vortex_cxx_bridge/include ) # Platform-specific configuration @@ -110,7 +110,7 @@ if 
(VORTEX_ENABLE_TESTING) set(APPLE_LINK_FLAGS "-framework CoreFoundation -framework Security") endif() target_link_libraries(vortex_cxx_test PRIVATE vortex_cxx_bridge ${APPLE_LINK_FLAGS}) - if (VORTEX_ENABLE_ASAN) + if (VORTEX_ENABLE_ASAN) target_compile_options(vortex_cxx_test PRIVATE -fsanitize=leak,address,undefined -fno-omit-frame-pointer -fno-common -O1) target_link_options(vortex_cxx_test PRIVATE -fsanitize=leak,address,undefined) endif() diff --git a/vortex-cxx/README.md b/vortex-cxx/README.md index e413ec15142..aa761f98b1a 100644 --- a/vortex-cxx/README.md +++ b/vortex-cxx/README.md @@ -7,7 +7,7 @@ This directory contains C++ bindings for Vortex using the [cxx](https://cxx.rs/) ### Requirements - CMake 3.22 or higher -- C++17 compatible compiler +- C++20 compatible compiler - Rust toolchain (for building the Rust components) - (optional) Ninja (`ninja-build`) diff --git a/vortex-cxx/examples/CMakeLists.txt b/vortex-cxx/examples/CMakeLists.txt index 40bb44dde71..2da1b30f3e7 100644 --- a/vortex-cxx/examples/CMakeLists.txt +++ b/vortex-cxx/examples/CMakeLists.txt @@ -5,7 +5,7 @@ cmake_minimum_required(VERSION 3.22) project(vortex-examples) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/vortex-duckdb/README.md b/vortex-duckdb/README.md index 25fbdab509c..416bbac7fcb 100644 --- a/vortex-duckdb/README.md +++ b/vortex-duckdb/README.md @@ -6,7 +6,7 @@ Rust bindings for DuckDB. 
Supports DuckDB precompiled libraries for fast builds - **Ninja**: `brew install ninja` (macOS) | `apt-get install ninja-build` (Ubuntu) - **CMake**: `brew install cmake` (macOS) | `apt-get install cmake` (Ubuntu) -- **C++17 compatible compiler**: GCC or Clang +- **C++20 compatible compiler**: GCC or Clang ## Build Modes diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 78d0b77f097..0b7a53bdaea 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -499,15 +499,11 @@ fn main() { // Compile our C++ code that exposes additional DuckDB functionality. cc::Build::new() - .std("c++17") + .std("c++20") // Enable compiler warnings. .flag("-Wall") .flag("-Wextra") .flag("-Wpedantic") - // Allow C++20 designator syntax even with C++17 std - .flag("-Wno-c++20-designator") - // Enable C++20 extensions - .flag("-Wno-c++20-extensions") // Unused parameter warnings are disabled as we include DuckDB // headers with implementations that have unused parameters. .flag("-Wno-unused-parameter") diff --git a/vortex-duckdb/cpp/CMakeLists.txt b/vortex-duckdb/cpp/CMakeLists.txt index 33de88deb2b..e0a40c4c918 100644 --- a/vortex-duckdb/cpp/CMakeLists.txt +++ b/vortex-duckdb/cpp/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.20) project(vortex_duckdb_cpp) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Default to debug if build config is not explicitly set. @@ -23,7 +23,7 @@ if (NOT CMAKE_BUILD_TYPE) endif () # Enable compiler warnings (matching build.rs flags). -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wpedantic -Wno-unused-parameter -Wno-c++20-designator -Wno-c++20-extensions") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wpedantic -Wno-unused-parameter") # Find DuckDB include directory via the symlink created by build.rs. 
# The symlink points to target/duckdb-source-vX.Y.Z which contains duckdb-X.Y.Z/ diff --git a/vortex-duckdb/src/convert/dtype.rs b/vortex-duckdb/src/convert/dtype.rs index aeffa40140b..7a96b66c729 100644 --- a/vortex-duckdb/src/convert/dtype.rs +++ b/vortex-duckdb/src/convert/dtype.rs @@ -591,18 +591,13 @@ mod tests { Ok(EmptyMetadata) } - fn validate_dtype( - &self, - _options: &Self::Metadata, - _storage_dtype: &DType, - ) -> VortexResult<()> { + fn validate_dtype(&self, _ext_dtype: &ExtDType) -> VortexResult<()> { Ok(()) } fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, _storage_value: &'a ScalarValue, ) -> VortexResult> { Ok("") diff --git a/vortex-jni/src/array.rs b/vortex-jni/src/array.rs index 3d41ad1e220..3fcdf041cb5 100644 --- a/vortex-jni/src/array.rs +++ b/vortex-jni/src/array.rs @@ -294,7 +294,7 @@ macro_rules! get_primitive { array_ref .inner .to_extension() - .storage() + .storage_array() .scalar_at(index as usize)? } else { array_ref.inner.scalar_at(index as usize)? @@ -333,7 +333,7 @@ pub extern "system" fn Java_dev_vortex_jni_NativeArrayMethods_getBigDecimal( array_ref .inner .to_extension() - .storage() + .storage_array() .scalar_at(index as usize)? } else { array_ref.inner.scalar_at(index as usize)? 
diff --git a/vortex-layout/src/layouts/row_idx/mod.rs b/vortex-layout/src/layouts/row_idx/mod.rs index fd7513aad96..8ce4a2f5f4b 100644 --- a/vortex-layout/src/layouts/row_idx/mod.rs +++ b/vortex-layout/src/layouts/row_idx/mod.rs @@ -9,6 +9,7 @@ use std::fmt::Formatter; use std::ops::BitAnd; use std::ops::Range; use std::sync::Arc; +use std::sync::OnceLock; use Nullability::NonNullable; pub use expr::*; @@ -47,7 +48,7 @@ pub struct RowIdxLayoutReader { name: Arc, row_offset: u64, child: Arc, - partition_cache: DashMap, + partition_cache: DashMap>>, session: VortexSession, } @@ -66,45 +67,52 @@ impl RowIdxLayoutReader { let key = ExactExpr(expr.clone()); // Check cache first with read-only lock. - if let Some(partitioning) = self.partition_cache.get(&key) { + if let Some(entry) = self.partition_cache.get(&key) + && let Some(partitioning) = entry.value().get() + { return partitioning.clone(); } - self.partition_cache + let cell = self + .partition_cache .entry(key) - .or_insert_with(|| { - // Partition the expression into row idx and child expressions. - let mut partitioned = partition(expr.clone(), self.dtype(), |expr| { - if expr.is::() { - vec![Partition::RowIdx] - } else if is_root(expr) { - vec![Partition::Child] - } else { - vec![] - } - }) - .vortex_expect("We should not fail to partition expression over struct fields"); + .or_insert_with(|| Arc::new(OnceLock::new())) + .clone(); - // If there's only a single partition, we can directly return the expression. - if partitioned.partitions.len() == 1 { - return match &partitioned.partition_annotations[0] { - Partition::RowIdx => { - Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root())) - } - Partition::Child => Partitioning::Child(expr.clone()), - }; + cell.get_or_init(|| self.compute_partitioning(expr)).clone() + } + + fn compute_partitioning(&self, expr: &Expression) -> Partitioning { + // Partition the expression into row idx and child expressions. 
+ let mut partitioned = partition(expr.clone(), self.dtype(), |expr| { + if expr.is::() { + vec![Partition::RowIdx] + } else if is_root(expr) { + vec![Partition::Child] + } else { + vec![] + } + }) + .vortex_expect("We should not fail to partition expression over struct fields"); + + // If there's only a single partition, we can directly return the expression. + if partitioned.partitions.len() == 1 { + return match &partitioned.partition_annotations[0] { + Partition::RowIdx => { + Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root())) } + Partition::Child => Partitioning::Child(expr.clone()), + }; + } - // Replace the row_idx expression with the root expression in the row_idx partition. - partitioned.partitions = partitioned - .partitions - .into_iter() - .map(|p| replace(p, &row_idx(), root())) - .collect(); + // Replace the row_idx expression with the root expression in the row_idx partition. + partitioned.partitions = partitioned + .partitions + .into_iter() + .map(|p| replace(p, &row_idx(), root())) + .collect(); - Partitioning::Partitioned(Arc::new(partitioned)) - }) - .clone() + Partitioning::Partitioned(Arc::new(partitioned)) } } diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs index fc665731bdd..bb069df9723 100644 --- a/vortex-layout/src/layouts/struct_/reader.rs +++ b/vortex-layout/src/layouts/struct_/reader.rs @@ -4,6 +4,7 @@ use std::collections::BTreeSet; use std::ops::Range; use std::sync::Arc; +use std::sync::OnceLock; use futures::try_join; use itertools::Itertools; @@ -57,7 +58,7 @@ pub struct StructReader { expanded_root_expr: Expression, field_lookup: Option>, - partitioned_expr_cache: DashMap, + partitioned_expr_cache: DashMap>>, } impl StructReader { @@ -152,51 +153,65 @@ impl StructReader { /// Utility for partitioning an expression over the fields of a struct. 
fn partition_expr(&self, expr: Expression) -> Partitioned { - self.partitioned_expr_cache - .entry(ExactExpr(expr.clone())) - .or_insert_with(|| { - // First, we expand the root scope into the fields of the struct to ensure - // that partitioning works correctly. - let expr = replace(expr.clone(), &root(), self.expanded_root_expr.clone()); - let expr = expr - .optimize_recursive(self.dtype()) - .vortex_expect("We should not fail to simplify expression over struct fields"); - - // Partition the expression into expressions that can be evaluated over individual fields - let mut partitioned = partition( - expr.clone(), - self.dtype(), - make_free_field_annotator( - self.dtype() - .as_struct_fields_opt() - .vortex_expect("We know it's a struct DType"), - ), - ) - .vortex_expect("We should not fail to partition expression over struct fields"); - - if partitioned.partitions.len() == 1 { - // If there's only one partition, we step into the field scope of the original - // expression by replacing any `$.a` with `$`. - return Partitioned::Single( - partitioned.partition_names[0].clone(), - replace(expr, &col(partitioned.partition_names[0].clone()), root()), - ); - } + let key = ExactExpr(expr.clone()); + + if let Some(entry) = self.partitioned_expr_cache.get(&key) + && let Some(partitioning) = entry.value().get() + { + return partitioning.clone(); + } - // We now need to process the partitioned expressions to rewrite the root scope - // to be that of the field, rather than the struct. In other words, "stepping in" - // to the field scope. 
- partitioned.partitions = partitioned - .partitions - .iter() - .zip_eq(partitioned.partition_names.iter()) - .map(|(e, name)| replace(e.clone(), &col(name.clone()), root())) - .collect(); - - Partitioned::Multi(Arc::new(partitioned)) - }) + let cell = self + .partitioned_expr_cache + .entry(key) + .or_insert_with(|| Arc::new(OnceLock::new())) + .clone(); + + cell.get_or_init(|| self.compute_partitioned_expr(expr)) .clone() } + + fn compute_partitioned_expr(&self, expr: Expression) -> Partitioned { + // First, we expand the root scope into the fields of the struct to ensure + // that partitioning works correctly. + let expr = replace(expr, &root(), self.expanded_root_expr.clone()); + let expr = expr + .optimize_recursive(self.dtype()) + .vortex_expect("We should not fail to simplify expression over struct fields"); + + // Partition the expression into expressions that can be evaluated over individual fields + let mut partitioned = partition( + expr.clone(), + self.dtype(), + make_free_field_annotator( + self.dtype() + .as_struct_fields_opt() + .vortex_expect("We know it's a struct DType"), + ), + ) + .vortex_expect("We should not fail to partition expression over struct fields"); + + if partitioned.partitions.len() == 1 { + // If there's only one partition, we step into the field scope of the original + // expression by replacing any `$.a` with `$`. + return Partitioned::Single( + partitioned.partition_names[0].clone(), + replace(expr, &col(partitioned.partition_names[0].clone()), root()), + ); + } + + // We now need to process the partitioned expressions to rewrite the root scope + // to be that of the field, rather than the struct. In other words, "stepping in" + // to the field scope. 
+ partitioned.partitions = partitioned + .partitions + .iter() + .zip_eq(partitioned.partition_names.iter()) + .map(|(e, name)| replace(e.clone(), &col(name.clone()), root())) + .collect(); + + Partitioned::Multi(Arc::new(partitioned)) + } } /// When partitioning an expression, in the case it only has a single partition we can avoid diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 1f87eb70f46..bf60d6fa5d6 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -102,10 +102,11 @@ where impl Eq for AllOr where T: Eq {} /// Represents a set of sorted unique positive integers. +/// If a value is included in a Mask, it's valid. /// /// A [`Mask`] can be constructed from various representations, and converted to various /// others. Internally, these are cached. -#[derive(Debug, Clone)] +#[derive(Clone)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] pub enum Mask { /// All values are included. @@ -116,6 +117,16 @@ pub enum Mask { Values(Arc), } +impl Debug for Mask { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::AllTrue(len) => write!(f, "All true({len})"), + Self::AllFalse(len) => write!(f, "All false({len})"), + Self::Values(mask) => write!(f, "{mask:?}"), + } + } +} + impl Default for Mask { fn default() -> Self { Self::new_true(0) @@ -123,7 +134,6 @@ impl Default for Mask { } /// Represents the values of a [`Mask`] that contains some true and some false elements. 
-#[derive(Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct MaskValues { buffer: BitBuffer, @@ -141,6 +151,23 @@ pub struct MaskValues { density: f64, } +impl Debug for MaskValues { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "true_count={}, ", self.true_count)?; + write!(f, "density={}, ", self.density)?; + if let Some(v) = self.indices.get() { + write!(f, "indices={v:?}, ")?; + } + if let Some(v) = self.slices.get() { + write!(f, "slices={v:?}, ")?; + } + if f.alternate() { + f.write_str("\n")?; + } + write!(f, "{}", self.buffer) + } +} + impl Mask { /// Create a new Mask with the given length. pub fn new(length: usize, value: bool) -> Self { diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 3046784af96..2104eb97c95 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -1,121 +1,123 @@ pub mod vortex_tensor -pub mod vortex_tensor::scalar_fns +pub mod vortex_tensor::fixed_shape -pub mod vortex_tensor::scalar_fns::cosine_similarity +pub struct vortex_tensor::fixed_shape::FixedShapeTensor -pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +impl core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensor -impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::clone(&self) -> vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +impl core::cmp::Eq for vortex_tensor::fixed_shape::FixedShapeTensor -impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +impl core::cmp::PartialEq for vortex_tensor::fixed_shape::FixedShapeTensor -pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = 
vortex_array::scalar_fn::vtable::EmptyOptions +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::eq(&self, other: &vortex_tensor::fixed_shape::FixedShapeTensor) -> bool -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity +impl core::default::Default for vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::default() -> vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +impl core::fmt::Debug for vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::id(&self) -> vortex_array::scalar_fn::ScalarFnId +impl core::hash::Hash for vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_fallible(&self, _options: &Self::Options) -> bool +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::hash<__H: core::hash::Hasher>(&self, state: &mut __H) -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_sensitive(&self, _options: &Self::Options) -> bool +impl core::marker::StructuralPartialEq for 
vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::fixed_shape::FixedShapeTensor -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> +pub type vortex_tensor::fixed_shape::FixedShapeTensor::Metadata = vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub struct vortex_tensor::FixedShapeTensor +pub type vortex_tensor::fixed_shape::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue -impl core::clone::Clone for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult -pub fn vortex_tensor::FixedShapeTensor::clone(&self) -> vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId -impl core::cmp::Eq for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> -impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::unpack_native<'a>(&self, _ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult -pub fn vortex_tensor::FixedShapeTensor::eq(&self, other: &vortex_tensor::FixedShapeTensor) -> bool +pub fn vortex_tensor::fixed_shape::FixedShapeTensor::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> -impl 
core::default::Default for vortex_tensor::FixedShapeTensor +pub struct vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensor::default() -> vortex_tensor::FixedShapeTensor +impl vortex_tensor::fixed_shape::FixedShapeTensorMetadata -impl core::fmt::Debug for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::dim_names(&self) -> core::option::Option<&[alloc::string::String]> -pub fn vortex_tensor::FixedShapeTensor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::logical_shape(&self) -> &[usize] -impl core::hash::Hash for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::ndim(&self) -> usize -pub fn vortex_tensor::FixedShapeTensor::hash<__H: core::hash::Hasher>(&self, state: &mut __H) +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::new(shape: alloc::vec::Vec) -> Self -impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::permutation(&self) -> core::option::Option<&[usize]> -impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::FixedShapeTensor +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::physical_shape(&self) -> impl core::iter::traits::iterator::Iterator + '_ -pub type vortex_tensor::FixedShapeTensor::Metadata = vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::strides(&self) -> impl core::iter::traits::iterator::Iterator + '_ -pub type vortex_tensor::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::with_dim_names(self, names: alloc::vec::Vec) -> vortex_error::VortexResult -pub fn vortex_tensor::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> 
vortex_error::VortexResult +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::with_permutation(self, permutation: alloc::vec::Vec) -> vortex_error::VortexResult -pub fn vortex_tensor::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId +impl core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensor::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::clone(&self) -> vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensor::unpack_native<'a>(&self, _metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult +impl core::cmp::Eq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensor::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> +impl core::cmp::PartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub struct vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::eq(&self, other: &vortex_tensor::fixed_shape::FixedShapeTensorMetadata) -> bool -impl vortex_tensor::FixedShapeTensorMetadata +impl core::fmt::Debug for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensorMetadata::dim_names(&self) -> core::option::Option<&[alloc::string::String]> +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub fn vortex_tensor::FixedShapeTensorMetadata::logical_shape(&self) -> &[usize] +impl core::fmt::Display for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensorMetadata::ndim(&self) -> usize +pub fn 
vortex_tensor::fixed_shape::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub fn vortex_tensor::FixedShapeTensorMetadata::new(shape: alloc::vec::Vec) -> Self +impl core::hash::Hash for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensorMetadata::permutation(&self) -> core::option::Option<&[usize]> +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H) -pub fn vortex_tensor::FixedShapeTensorMetadata::physical_shape(&self) -> impl core::iter::traits::iterator::Iterator + '_ +impl core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata -pub fn vortex_tensor::FixedShapeTensorMetadata::strides(&self) -> impl core::iter::traits::iterator::Iterator + '_ +pub mod vortex_tensor::scalar_fns -pub fn vortex_tensor::FixedShapeTensorMetadata::with_dim_names(self, names: alloc::vec::Vec) -> vortex_error::VortexResult +pub mod vortex_tensor::scalar_fns::cosine_similarity + +pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -pub fn vortex_tensor::FixedShapeTensorMetadata::with_permutation(self, permutation: alloc::vec::Vec) -> vortex_error::VortexResult +impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -impl core::clone::Clone for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -pub fn vortex_tensor::FixedShapeTensorMetadata::clone(&self) -> vortex_tensor::FixedShapeTensorMetadata +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -impl core::cmp::Eq for vortex_tensor::FixedShapeTensorMetadata +pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions -impl 
core::cmp::PartialEq for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity -pub fn vortex_tensor::FixedShapeTensorMetadata::eq(&self, other: &vortex_tensor::FixedShapeTensorMetadata) -> bool +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -impl core::fmt::Debug for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl core::fmt::Display for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::id(&self) -> vortex_array::scalar_fn::ScalarFnId -pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_fallible(&self, _options: &Self::Options) -> bool -impl core::hash::Hash for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_sensitive(&self, _options: &Self::Options) -> bool -pub fn vortex_tensor::FixedShapeTensorMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H) +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: 
&Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult -impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensorMetadata +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs similarity index 100% rename from vortex-tensor/src/metadata.rs rename to vortex-tensor/src/fixed_shape/metadata.rs diff --git a/vortex-tensor/src/fixed_shape/mod.rs b/vortex-tensor/src/fixed_shape/mod.rs new file mode 100644 index 00000000000..1c2e8b801c7 --- /dev/null +++ b/vortex-tensor/src/fixed_shape/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Fixed-shape Tensor extension type. + +/// The VTable for the Tensor extension type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct FixedShapeTensor; + +mod metadata; +pub use metadata::FixedShapeTensorMetadata; + +mod proto; +mod vtable; diff --git a/vortex-tensor/src/proto.rs b/vortex-tensor/src/fixed_shape/proto.rs similarity index 98% rename from vortex-tensor/src/proto.rs rename to vortex-tensor/src/fixed_shape/proto.rs index f454531fcca..06b4f45b726 100644 --- a/vortex-tensor/src/proto.rs +++ b/vortex-tensor/src/fixed_shape/proto.rs @@ -8,7 +8,7 @@ use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_err; -use crate::FixedShapeTensorMetadata; +use crate::fixed_shape::FixedShapeTensorMetadata; /// Protobuf representation of [`FixedShapeTensorMetadata`]. 
/// diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/fixed_shape/vtable.rs similarity index 87% rename from vortex-tensor/src/vtable.rs rename to vortex-tensor/src/fixed_shape/vtable.rs index ecec816516a..15e47456ba9 100644 --- a/vortex-tensor/src/vtable.rs +++ b/vortex-tensor/src/fixed_shape/vtable.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::dtype::DType; +use vortex::dtype::extension::ExtDType; use vortex::dtype::extension::ExtId; use vortex::dtype::extension::ExtVTable; use vortex::error::VortexResult; @@ -10,9 +11,9 @@ use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; use vortex::scalar::ScalarValue; -use crate::FixedShapeTensor; -use crate::FixedShapeTensorMetadata; -use crate::proto; +use crate::fixed_shape::FixedShapeTensor; +use crate::fixed_shape::FixedShapeTensorMetadata; +use crate::fixed_shape::proto; impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; @@ -32,7 +33,8 @@ impl ExtVTable for FixedShapeTensor { proto::deserialize(metadata) } - fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { + fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()> { + let storage_dtype = ext_dtype.storage_dtype(); let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else { vortex_bail!( "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}" @@ -50,7 +52,7 @@ impl ExtVTable for FixedShapeTensor { "FixedShapeTensor element dtype must be non-nullable (may change in the future)" ); - let element_count: usize = metadata.logical_shape().iter().product(); + let element_count: usize = ext_dtype.metadata().logical_shape().iter().product(); vortex_ensure_eq!( element_count, *list_size as usize, @@ -63,8 +65,7 @@ impl ExtVTable for FixedShapeTensor { fn unpack_native<'a>( &self, - _metadata: &'a Self::Metadata, - _storage_dtype: &'a DType, + _ext_dtype: &'a ExtDType, 
storage_value: &'a ScalarValue, ) -> VortexResult> { // TODO(connor): This is just a placeholder. However, even if we have a dedicated native @@ -80,8 +81,8 @@ mod tests { use vortex::dtype::extension::ExtVTable; use vortex::error::VortexResult; - use crate::FixedShapeTensor; - use crate::FixedShapeTensorMetadata; + use crate::fixed_shape::FixedShapeTensor; + use crate::fixed_shape::FixedShapeTensorMetadata; /// Serializes and deserializes the given metadata through protobuf, asserting equality. fn assert_roundtrip(metadata: &FixedShapeTensorMetadata) -> VortexResult<()> { diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index ab18826c6b6..dc33066bd3b 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -1,16 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Tensor extension type. +//! Types and functionality for working with tensors, vectors, and related mathematical constructs +//! including unit vectors, spherical coordinates, and similarity measures such as cosine +//! similarity. -mod metadata; -pub use metadata::FixedShapeTensorMetadata; - -mod proto; -mod vtable; +pub mod fixed_shape; pub mod scalar_fns; - -/// The VTable for the Tensor extension type. -#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] -pub struct FixedShapeTensor; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 1746e6a5a75..4bcc6f3b41b 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -1,7 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Cosine similarity expression for [`FixedShapeTensor`](crate::FixedShapeTensor) arrays. +//! Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) +//! arrays. 
use std::fmt::Formatter; @@ -30,16 +31,18 @@ use vortex::scalar_fn::ExecutionArgs; use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; -/// Cosine similarity between two [`FixedShapeTensor`] columns. +// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors +// encoded in spherical coordinates. +/// Cosine similarity between two columns. /// -/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor. The -/// shape and permutation do not affect the result because cosine similarity only depends on the -/// element values, not their logical arrangement. +/// For [`FixedShapeTensor`], computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of +/// each tensor. The shape and permutation do not affect the result because cosine similarity only +/// depends on the element values, not their logical arrangement. /// -/// Both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a float -/// element type (`f32` or `f64`). The output is a primitive column of the same float type. +/// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a +/// float element type. The output is a float column of the same float type. /// -/// [`FixedShapeTensor`]: crate::FixedShapeTensor +/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor #[derive(Clone)] pub struct CosineSimilarity; @@ -196,7 +199,7 @@ fn extension_storage(array: &ArrayRef) -> VortexResult { let ext = array .as_opt::() .ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?; - Ok(ext.storage().clone()) + Ok(ext.storage_array().clone()) } /// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). 
@@ -259,8 +262,8 @@ mod tests { use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; - use crate::FixedShapeTensor; - use crate::FixedShapeTensorMetadata; + use crate::fixed_shape::FixedShapeTensor; + use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::cosine_similarity::CosineSimilarity; /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. diff --git a/vortex/benches/common_encoding_tree_throughput.rs b/vortex/benches/common_encoding_tree_throughput.rs index f4243a7d332..b60a6f87bd2 100644 --- a/vortex/benches/common_encoding_tree_throughput.rs +++ b/vortex/benches/common_encoding_tree_throughput.rs @@ -44,7 +44,7 @@ fn main() { divan::main(); } -const NUM_VALUES: u64 = 1_000_000; +const NUM_VALUES: u64 = 100_000; // Helper function to conditionally add counter based on codspeed cfg fn with_byte_counter<'a, 'b>(bencher: Bencher<'a, 'b>, bytes: u64) -> Bencher<'a, 'b> { diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index dd22c30b996..50ceaf7fbdb 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors #![allow(clippy::unwrap_used)] +#![allow(clippy::cast_possible_truncation)] #![allow(unexpected_cfgs)] use divan::Bencher; @@ -41,7 +42,7 @@ fn main() { divan::main(); } -const NUM_VALUES: u64 = 1_000_000; +const NUM_VALUES: u64 = 100_000; // Helper function to conditionally add counter based on codspeed cfg fn with_byte_counter<'a, 'b>(bencher: Bencher<'a, 'b>, bytes: u64) -> Bencher<'a, 'b> { @@ -326,7 +327,8 @@ fn bench_zstd_decompress_u32(bencher: Bencher) { // String compression benchmarks #[divan::bench(name = "dict_compress_string")] fn bench_dict_compress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + 
VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let nbytes = varbinview_arr.nbytes() as u64; with_byte_counter(bencher, nbytes) @@ -336,7 +338,8 @@ fn bench_dict_compress_string(bencher: Bencher) { #[divan::bench(name = "dict_decompress_string")] fn bench_dict_decompress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let dict = dict_encode(&varbinview_arr.clone().into_array()).unwrap(); let nbytes = varbinview_arr.into_array().nbytes() as u64; @@ -347,7 +350,8 @@ fn bench_dict_decompress_string(bencher: Bencher) { #[divan::bench(name = "fsst_compress_string")] fn bench_fsst_compress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let fsst_compressor = fsst_train_compressor(&varbinview_arr); let nbytes = varbinview_arr.nbytes() as u64; @@ -358,7 +362,8 @@ fn bench_fsst_compress_string(bencher: Bencher) { #[divan::bench(name = "fsst_decompress_string")] fn bench_fsst_decompress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let fsst_compressor = fsst_train_compressor(&varbinview_arr); let fsst_array = fsst_compress(&varbinview_arr, &fsst_compressor); let nbytes = varbinview_arr.into_array().nbytes() as u64; @@ -371,7 +376,8 @@ fn bench_fsst_decompress_string(bencher: Bencher) { #[cfg(feature = "zstd")] #[divan::bench(name = "zstd_compress_string")] fn bench_zstd_compress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + 
VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let nbytes = varbinview_arr.nbytes() as u64; let array = varbinview_arr.into_array(); @@ -383,7 +389,8 @@ fn bench_zstd_compress_string(bencher: Bencher) { #[cfg(feature = "zstd")] #[divan::bench(name = "zstd_decompress_string")] fn bench_zstd_decompress_string(bencher: Bencher) { - let varbinview_arr = VarBinViewArray::from_iter_str(gen_varbin_words(1_000_000, 0.00005)); + let varbinview_arr = + VarBinViewArray::from_iter_str(gen_varbin_words(NUM_VALUES as usize, 0.00005)); let compressed = ZstdArray::from_array(varbinview_arr.clone().into_array(), 3, 8192).unwrap(); let nbytes = varbinview_arr.into_array().nbytes() as u64;